snowflake-ml-python 1.2.0__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (194)
  1. snowflake/ml/_internal/env_utils.py +16 -13
  2. snowflake/ml/_internal/exceptions/modeling_error_messages.py +5 -1
  3. snowflake/ml/_internal/telemetry.py +19 -0
  4. snowflake/ml/feature_store/__init__.py +9 -0
  5. snowflake/ml/feature_store/entity.py +73 -0
  6. snowflake/ml/feature_store/feature_store.py +1657 -0
  7. snowflake/ml/feature_store/feature_view.py +459 -0
  8. snowflake/ml/model/_client/ops/model_ops.py +16 -38
  9. snowflake/ml/model/_client/sql/model.py +1 -7
  10. snowflake/ml/model/_client/sql/model_version.py +20 -15
  11. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +9 -1
  12. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  13. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +12 -2
  14. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +7 -3
  15. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +1 -6
  16. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +0 -2
  17. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  18. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -2
  19. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  20. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  21. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  22. snowflake/ml/model/model_signature.py +72 -16
  23. snowflake/ml/model/type_hints.py +12 -0
  24. snowflake/ml/modeling/_internal/estimator_protocols.py +1 -41
  25. snowflake/ml/modeling/_internal/model_trainer_builder.py +13 -9
  26. snowflake/ml/modeling/_internal/{distributed_hpo_trainer.py → snowpark_implementations/distributed_hpo_trainer.py} +66 -96
  27. snowflake/ml/modeling/_internal/{snowpark_handlers.py → snowpark_implementations/snowpark_handlers.py} +9 -6
  28. snowflake/ml/modeling/_internal/{xgboost_external_memory_trainer.py → snowpark_implementations/xgboost_external_memory_trainer.py} +3 -1
  29. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +19 -3
  30. snowflake/ml/modeling/cluster/affinity_propagation.py +19 -3
  31. snowflake/ml/modeling/cluster/agglomerative_clustering.py +19 -3
  32. snowflake/ml/modeling/cluster/birch.py +19 -3
  33. snowflake/ml/modeling/cluster/bisecting_k_means.py +19 -3
  34. snowflake/ml/modeling/cluster/dbscan.py +19 -3
  35. snowflake/ml/modeling/cluster/feature_agglomeration.py +19 -3
  36. snowflake/ml/modeling/cluster/k_means.py +19 -3
  37. snowflake/ml/modeling/cluster/mean_shift.py +19 -3
  38. snowflake/ml/modeling/cluster/mini_batch_k_means.py +19 -3
  39. snowflake/ml/modeling/cluster/optics.py +19 -3
  40. snowflake/ml/modeling/cluster/spectral_biclustering.py +19 -3
  41. snowflake/ml/modeling/cluster/spectral_clustering.py +19 -3
  42. snowflake/ml/modeling/cluster/spectral_coclustering.py +19 -3
  43. snowflake/ml/modeling/compose/column_transformer.py +19 -3
  44. snowflake/ml/modeling/compose/transformed_target_regressor.py +19 -3
  45. snowflake/ml/modeling/covariance/elliptic_envelope.py +19 -3
  46. snowflake/ml/modeling/covariance/empirical_covariance.py +19 -3
  47. snowflake/ml/modeling/covariance/graphical_lasso.py +19 -3
  48. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +19 -3
  49. snowflake/ml/modeling/covariance/ledoit_wolf.py +19 -3
  50. snowflake/ml/modeling/covariance/min_cov_det.py +19 -3
  51. snowflake/ml/modeling/covariance/oas.py +19 -3
  52. snowflake/ml/modeling/covariance/shrunk_covariance.py +19 -3
  53. snowflake/ml/modeling/decomposition/dictionary_learning.py +19 -3
  54. snowflake/ml/modeling/decomposition/factor_analysis.py +19 -3
  55. snowflake/ml/modeling/decomposition/fast_ica.py +19 -3
  56. snowflake/ml/modeling/decomposition/incremental_pca.py +19 -3
  57. snowflake/ml/modeling/decomposition/kernel_pca.py +19 -3
  58. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +19 -3
  59. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +19 -3
  60. snowflake/ml/modeling/decomposition/pca.py +19 -3
  61. snowflake/ml/modeling/decomposition/sparse_pca.py +19 -3
  62. snowflake/ml/modeling/decomposition/truncated_svd.py +19 -3
  63. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +19 -3
  64. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +19 -3
  65. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +19 -3
  66. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +19 -3
  67. snowflake/ml/modeling/ensemble/bagging_classifier.py +19 -3
  68. snowflake/ml/modeling/ensemble/bagging_regressor.py +19 -3
  69. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +19 -3
  70. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +19 -3
  71. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +19 -3
  72. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +19 -3
  73. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +19 -3
  74. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +19 -3
  75. snowflake/ml/modeling/ensemble/isolation_forest.py +19 -3
  76. snowflake/ml/modeling/ensemble/random_forest_classifier.py +19 -3
  77. snowflake/ml/modeling/ensemble/random_forest_regressor.py +19 -3
  78. snowflake/ml/modeling/ensemble/stacking_regressor.py +19 -3
  79. snowflake/ml/modeling/ensemble/voting_classifier.py +19 -3
  80. snowflake/ml/modeling/ensemble/voting_regressor.py +19 -3
  81. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +19 -3
  82. snowflake/ml/modeling/feature_selection/select_fdr.py +19 -3
  83. snowflake/ml/modeling/feature_selection/select_fpr.py +19 -3
  84. snowflake/ml/modeling/feature_selection/select_fwe.py +19 -3
  85. snowflake/ml/modeling/feature_selection/select_k_best.py +19 -3
  86. snowflake/ml/modeling/feature_selection/select_percentile.py +19 -3
  87. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +19 -3
  88. snowflake/ml/modeling/feature_selection/variance_threshold.py +19 -3
  89. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +19 -3
  90. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +19 -3
  91. snowflake/ml/modeling/impute/iterative_imputer.py +19 -3
  92. snowflake/ml/modeling/impute/knn_imputer.py +19 -3
  93. snowflake/ml/modeling/impute/missing_indicator.py +19 -3
  94. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +19 -3
  95. snowflake/ml/modeling/kernel_approximation/nystroem.py +19 -3
  96. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +19 -3
  97. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +19 -3
  98. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +19 -3
  99. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +19 -3
  100. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +19 -3
  101. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +19 -3
  102. snowflake/ml/modeling/linear_model/ard_regression.py +19 -3
  103. snowflake/ml/modeling/linear_model/bayesian_ridge.py +19 -3
  104. snowflake/ml/modeling/linear_model/elastic_net.py +19 -3
  105. snowflake/ml/modeling/linear_model/elastic_net_cv.py +19 -3
  106. snowflake/ml/modeling/linear_model/gamma_regressor.py +19 -3
  107. snowflake/ml/modeling/linear_model/huber_regressor.py +19 -3
  108. snowflake/ml/modeling/linear_model/lars.py +19 -3
  109. snowflake/ml/modeling/linear_model/lars_cv.py +19 -3
  110. snowflake/ml/modeling/linear_model/lasso.py +19 -3
  111. snowflake/ml/modeling/linear_model/lasso_cv.py +19 -3
  112. snowflake/ml/modeling/linear_model/lasso_lars.py +19 -3
  113. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +19 -3
  114. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +19 -3
  115. snowflake/ml/modeling/linear_model/linear_regression.py +19 -3
  116. snowflake/ml/modeling/linear_model/logistic_regression.py +19 -3
  117. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +19 -3
  118. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +19 -3
  119. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +19 -3
  120. snowflake/ml/modeling/linear_model/multi_task_lasso.py +19 -3
  121. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +19 -3
  122. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +19 -3
  123. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +19 -3
  124. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +19 -3
  125. snowflake/ml/modeling/linear_model/perceptron.py +19 -3
  126. snowflake/ml/modeling/linear_model/poisson_regressor.py +19 -3
  127. snowflake/ml/modeling/linear_model/ransac_regressor.py +19 -3
  128. snowflake/ml/modeling/linear_model/ridge.py +19 -3
  129. snowflake/ml/modeling/linear_model/ridge_classifier.py +19 -3
  130. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +19 -3
  131. snowflake/ml/modeling/linear_model/ridge_cv.py +19 -3
  132. snowflake/ml/modeling/linear_model/sgd_classifier.py +19 -3
  133. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +19 -3
  134. snowflake/ml/modeling/linear_model/sgd_regressor.py +19 -3
  135. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +19 -3
  136. snowflake/ml/modeling/linear_model/tweedie_regressor.py +19 -3
  137. snowflake/ml/modeling/manifold/isomap.py +19 -3
  138. snowflake/ml/modeling/manifold/mds.py +19 -3
  139. snowflake/ml/modeling/manifold/spectral_embedding.py +19 -3
  140. snowflake/ml/modeling/manifold/tsne.py +19 -3
  141. snowflake/ml/modeling/metrics/classification.py +5 -6
  142. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  143. snowflake/ml/modeling/metrics/ranking.py +7 -3
  144. snowflake/ml/modeling/metrics/regression.py +6 -3
  145. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +19 -3
  146. snowflake/ml/modeling/mixture/gaussian_mixture.py +19 -3
  147. snowflake/ml/modeling/model_selection/grid_search_cv.py +3 -13
  148. snowflake/ml/modeling/model_selection/randomized_search_cv.py +3 -13
  149. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +19 -3
  150. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +19 -3
  151. snowflake/ml/modeling/multiclass/output_code_classifier.py +19 -3
  152. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +19 -3
  153. snowflake/ml/modeling/naive_bayes/categorical_nb.py +19 -3
  154. snowflake/ml/modeling/naive_bayes/complement_nb.py +19 -3
  155. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +19 -3
  156. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +19 -3
  157. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +19 -3
  158. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +19 -3
  159. snowflake/ml/modeling/neighbors/kernel_density.py +19 -3
  160. snowflake/ml/modeling/neighbors/local_outlier_factor.py +19 -3
  161. snowflake/ml/modeling/neighbors/nearest_centroid.py +19 -3
  162. snowflake/ml/modeling/neighbors/nearest_neighbors.py +19 -3
  163. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +19 -3
  164. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +19 -3
  165. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +19 -3
  166. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +19 -3
  167. snowflake/ml/modeling/neural_network/mlp_classifier.py +19 -3
  168. snowflake/ml/modeling/neural_network/mlp_regressor.py +19 -3
  169. snowflake/ml/modeling/preprocessing/polynomial_features.py +19 -3
  170. snowflake/ml/modeling/semi_supervised/label_propagation.py +19 -3
  171. snowflake/ml/modeling/semi_supervised/label_spreading.py +19 -3
  172. snowflake/ml/modeling/svm/linear_svc.py +19 -3
  173. snowflake/ml/modeling/svm/linear_svr.py +19 -3
  174. snowflake/ml/modeling/svm/nu_svc.py +19 -3
  175. snowflake/ml/modeling/svm/nu_svr.py +19 -3
  176. snowflake/ml/modeling/svm/svc.py +19 -3
  177. snowflake/ml/modeling/svm/svr.py +19 -3
  178. snowflake/ml/modeling/tree/decision_tree_classifier.py +19 -3
  179. snowflake/ml/modeling/tree/decision_tree_regressor.py +19 -3
  180. snowflake/ml/modeling/tree/extra_tree_classifier.py +19 -3
  181. snowflake/ml/modeling/tree/extra_tree_regressor.py +19 -3
  182. snowflake/ml/modeling/xgboost/xgb_classifier.py +19 -3
  183. snowflake/ml/modeling/xgboost/xgb_regressor.py +19 -3
  184. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +19 -3
  185. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +19 -3
  186. snowflake/ml/registry/registry.py +2 -0
  187. snowflake/ml/version.py +1 -1
  188. snowflake_ml_python-1.2.2.dist-info/LICENSE.txt +202 -0
  189. {snowflake_ml_python-1.2.0.dist-info → snowflake_ml_python-1.2.2.dist-info}/METADATA +276 -50
  190. {snowflake_ml_python-1.2.0.dist-info → snowflake_ml_python-1.2.2.dist-info}/RECORD +204 -197
  191. {snowflake_ml_python-1.2.0.dist-info → snowflake_ml_python-1.2.2.dist-info}/WHEEL +2 -1
  192. snowflake_ml_python-1.2.2.dist-info/top_level.txt +1 -0
  193. /snowflake/ml/modeling/_internal/{pandas_trainer.py → local_implementations/pandas_trainer.py} +0 -0
  194. /snowflake/ml/modeling/_internal/{snowpark_trainer.py → snowpark_implementations/snowpark_trainer.py} +0 -0
snowflake/ml/modeling/_internal/{distributed_hpo_trainer.py → snowpark_implementations/distributed_hpo_trainer.py}
@@ -4,11 +4,12 @@ import io
  import os
  import posixpath
  import sys
- from typing import Any, Dict, List, Optional, Set, Tuple, Union
+ from typing import Any, Dict, List, Optional, Tuple, Union

  import cloudpickle as cp
  import numpy as np
  from sklearn import model_selection
+ from sklearn.model_selection import GridSearchCV, RandomizedSearchCV

  from snowflake.ml._internal import telemetry
  from snowflake.ml._internal.utils import (
@@ -23,7 +24,9 @@ from snowflake.ml._internal.utils.temp_file_utils import (
  from snowflake.ml.modeling._internal.model_specifications import (
  ModelSpecificationsBuilder,
  )
- from snowflake.ml.modeling._internal.snowpark_trainer import SnowparkModelTrainer
+ from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_trainer import (
+ SnowparkModelTrainer,
+ )
  from snowflake.snowpark import DataFrame, Session, functions as F
  from snowflake.snowpark._internal.utils import (
  TempObjectType,
@@ -41,23 +44,28 @@ DEFAULT_UDTF_NJOBS = 3


  def construct_cv_results(
+ estimator: Union[GridSearchCV, RandomizedSearchCV],
+ n_split: int,
+ param_grid: List[Dict[str, Any]],
  cv_results_raw_hex: List[Row],
  cross_validator_indices_length: int,
  parameter_grid_length: int,
- search_cv_kwargs: Dict[str, Any],
- ) -> Tuple[bool, Dict[str, Any], int, Set[str]]:
+ ) -> Tuple[bool, Dict[str, Any]]:
  """Construct the cross validation result from the UDF. Because we accelerate the process
  by the number of cross validation number, and the combination of parameter grids.
  Therefore, we need to stick them back together instead of returning the raw result
  to align with original sklearn result.

  Args:
+ estimator (Union[GridSearchCV, RandomizedSearchCV]): The sklearn object of estimator
+ GridSearchCV or RandomizedSearchCV
+ n_split (int): The number of split, which is determined by build_cross_validator.get_n_splits(X, y, groups)
+ param_grid (List[Dict[str, Any]]): the list of parameter grid or parameter sampler
  cv_results_raw_hex (List[Row]): the list of cv_results from each cv and parameter grid combination.
  Because UDxF can only return string, and numpy array/masked arrays cannot be encoded in a
  json format. Each cv_result is encoded into hex string.
  cross_validator_indices_length (int): the length of cross validator indices
  parameter_grid_length (int): the length of parameter grid combination
- search_cv_kwargs (Dict[str, Any]): the kwargs for GridSearchCV/RandomSearchCV.

  Raises:
  ValueError: Retrieved empty cross validation results
@@ -67,7 +75,7 @@ def construct_cv_results(
  RuntimeError: Cross validation results are unexpectedly empty for one fold.

  Returns:
- Tuple[bool, Dict[str, Any], int, Set[str]]: returns multimetric, cv_results_, best_param_index, scorers
+ Tuple[bool, Dict[str, Any]]: returns multimetric, cv_results_
  """
  # Filter corner cases: either the snowpark dataframe result is empty; or index length is empty
  if len(cv_results_raw_hex) == 0:
@@ -79,12 +87,8 @@
  if parameter_grid_length == 0:
  raise ValueError("Parameter index length is 0. Were there no candidates?")

- from scipy.stats import rankdata
-
  # cv_result maintains the original order
  multimetric = False
- cv_results_ = dict()
- scorers = set()
  # retrieve the cv_results from udtf table; results are encoded by hex and cloudpickle;
  # We are constructing the raw information back to original form
  if len(cv_results_raw_hex) != cross_validator_indices_length * parameter_grid_length:
@@ -94,7 +98,9 @@
  "Please retry or contact snowflake support."
  )

- for param_cv_indices, each_cv_result_hex in enumerate(cv_results_raw_hex):
+ out = []
+
+ for each_cv_result_hex in cv_results_raw_hex:
  # convert the hex string back to cv_results_
  hex_str = bytes.fromhex(each_cv_result_hex[0])
  with io.BytesIO(hex_str) as f_reload:
@@ -103,85 +109,46 @@ def construct_cv_results(
  raise RuntimeError(
  "Cross validation response is empty. This issue may be temporary - please try again."
  )
- for k, v in each_cv_result.items():
- cur_cv_idx = param_cv_indices % cross_validator_indices_length
- key = k
- if "split0_test_" in k:
+ temp_dict = dict()
+ """
+ This dictionary has the following keys
+ train_scores : dict of scorer name -> float
+ Score on training set (for all the scorers),
+ returned only if `return_train_score` is `True`.
+ test_scores : dict of scorer name -> float
+ Score on testing set (for all the scorers).
+ fit_time : float
+ Time spent for fitting in seconds.
+ score_time : float
+ Time spent for scoring in seconds.
+ """
+ if estimator.return_train_score:
+ if each_cv_result.get("split0_train_score", None):
+ # for single scorer, the split0_train_score only contains an array with one value
+ temp_dict["train_scores"] = each_cv_result["split0_train_score"][0]
+ else:
+ # if multimetric situation, the format would be
+ # {metric_name1: value, metric_name2: value, ...}
+ temp_dict["train_scores"] = {}
  # For multi-metric evaluation, the scores for all the scorers are available in the
  # cv_results_ dict at the keys ending with that scorer’s name ('_<scorer_name>')
  # instead of '_score'.
- scorers.add(k[len("split0_test_") :])
- key = k.replace("split0_test", f"split{cur_cv_idx}_test")
- if search_cv_kwargs.get("return_train_score", None) and "split0_train_" in k:
- key = k.replace("split0_train", f"split{cur_cv_idx}_train")
- elif k.startswith("param"):
- if cur_cv_idx != 0:
- continue
- if key:
- if key not in cv_results_:
- cv_results_[key] = v
- else:
- cv_results_[key] = np.concatenate([cv_results_[key], v])
-
- multimetric = len(scorers) > 1
- # Use numpy to re-calculate all the information in cv_results_ again
- # Generally speaking, reshape all the results into the (scorers+2, idx_length, params_length) shape,
- # and average them by the idx_length;
- # idx_length is the number of cv folds; params_length is the number of parameter combinations
- scores_test = [
- np.reshape(
- np.concatenate(
- [cv_results_[f"split{cur_cv}_test_{score}"] for cur_cv in range(cross_validator_indices_length)]
- ),
- (cross_validator_indices_length, -1),
- )
- for score in scorers
- ]
-
- fit_score_test_matrix = np.stack(
- [
- np.reshape(cv_results_["mean_fit_time"], (cross_validator_indices_length, -1)),
- np.reshape(cv_results_["mean_score_time"], (cross_validator_indices_length, -1)),
- ]
- + scores_test
- )
- mean_fit_score_test_matrix = np.mean(fit_score_test_matrix, axis=1)
- std_fit_score_test_matrix = np.std(fit_score_test_matrix, axis=1)
-
- if search_cv_kwargs.get("return_train_score", None):
- scores_train = [
- np.reshape(
- np.concatenate(
- [cv_results_[f"split{cur_cv}_train_{score}"] for cur_cv in range(cross_validator_indices_length)]
- ),
- (cross_validator_indices_length, -1),
- )
- for score in scorers
- ]
- mean_fit_score_train_matrix = np.mean(scores_train, axis=1)
- std_fit_score_train_matrix = np.std(scores_train, axis=1)
-
- cv_results_["std_fit_time"] = std_fit_score_test_matrix[0]
- cv_results_["mean_fit_time"] = mean_fit_score_test_matrix[0]
- cv_results_["std_score_time"] = std_fit_score_test_matrix[1]
- cv_results_["mean_score_time"] = mean_fit_score_test_matrix[1]
- for idx, score in enumerate(scorers):
- cv_results_[f"std_test_{score}"] = std_fit_score_test_matrix[idx + 2]
- cv_results_[f"mean_test_{score}"] = mean_fit_score_test_matrix[idx + 2]
- if search_cv_kwargs.get("return_train_score", None):
- cv_results_[f"std_train_{score}"] = std_fit_score_train_matrix[idx]
- cv_results_[f"mean_train_{score}"] = mean_fit_score_train_matrix[idx]
- # re-compute the ranking again with mean_test_<score>.
- cv_results_[f"rank_test_{score}"] = rankdata(-cv_results_[f"mean_test_{score}"], method="min")
- # The best param is the highest ranking (which is 1) and we choose the first time ranking 1 appeared.
- # If all scores are `nan`, `rankdata` will also produce an array of `nan` values.
- # In that case, default to first index.
- best_param_index = (
- np.where(cv_results_[f"rank_test_{score}"] == 1)[0][0]
- if not np.isnan(cv_results_[f"rank_test_{score}"]).all()
- else 0
- )
- return multimetric, cv_results_, best_param_index, scorers
+ for k, v in each_cv_result.items():
+ if "split0_train_" in k:
+ temp_dict["train_scores"][k[len("split0_train_") :]] = v
+ if isinstance(each_cv_result.get("split0_test_score"), np.ndarray):
+ temp_dict["test_scores"] = each_cv_result["split0_test_score"][0]
+ else:
+ temp_dict["test_scores"] = {}
+ for k, v in each_cv_result.items():
+ if "split0_test_" in k:
+ temp_dict["test_scores"][k[len("split0_test_") :]] = v
+ temp_dict["fit_time"] = each_cv_result["mean_fit_time"][0]
+ temp_dict["score_time"] = each_cv_result["mean_score_time"][0]
+ out.append(temp_dict)
+ first_test_score = out[0]["test_scores"]
+ multimetric = isinstance(first_test_score, dict)
+ return multimetric, estimator._format_results(param_grid, n_split, out)


  cp.register_pickle_by_value(inspect.getmodule(construct_cv_results))
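The decode-and-reassemble step above can be hard to follow in diff form. Below is a small, self-contained sketch of the round trip it relies on: each per-fold, per-candidate result is cloudpickled and hex-encoded inside the UDTF (UDxFs can only return strings), then decoded back and reduced to the `train_scores`/`test_scores`/`fit_time`/`score_time` dict that `construct_cv_results` appends to `out` before handing everything to sklearn's result formatting. `fake_cv_result` is a made-up payload for illustration, not the real UDTF output.

```python
import io

import cloudpickle as cp
import numpy as np

# Stand-in for one fold/candidate cv_result produced inside the UDTF (hypothetical values).
fake_cv_result = {
    "mean_fit_time": np.array([0.12]),
    "mean_score_time": np.array([0.01]),
    "split0_test_score": np.array([0.91]),
}

# Encoding side: pickle, then hex-encode so the payload survives as a string column.
encoded_hex = cp.dumps(fake_cv_result).hex()

# Decoding side, as in construct_cv_results: hex -> bytes -> cloudpickle.load.
with io.BytesIO(bytes.fromhex(encoded_hex)) as f_reload:
    each_cv_result = cp.load(f_reload)

# One entry of `out` for the single-metric case with return_train_score=False.
temp_dict = {
    "test_scores": each_cv_result["split0_test_score"][0],
    "fit_time": each_cv_result["mean_fit_time"][0],
    "score_time": each_cv_result["mean_score_time"][0],
}
print(temp_dict)  # {'test_scores': 0.91, 'fit_time': 0.12, 'score_time': 0.01}
```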
@@ -288,7 +255,6 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
  inspect.currentframe(), self.__class__.__name__
  ),
  api_calls=[sproc],
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
  )
  udtf_statement_params = telemetry.get_function_usage_statement_params(
  project=_PROJECT,
@@ -297,7 +263,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
  inspect.currentframe(), self.__class__.__name__
  ),
  api_calls=[udtf],
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
+ custom_tags=dict([("hpo_udtf", True)]),
  )

  # Put locally serialized estimator on stage.
@@ -375,8 +341,12 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
  estimator = cp.load(local_estimator_file_obj)["estimator"]

  build_cross_validator = check_cv(estimator.cv, y, classifier=is_classifier(estimator.estimator))
+ from sklearn.utils.validation import indexable
+
+ X, y, _ = indexable(X, y, None)
+ n_splits = build_cross_validator.get_n_splits(X, y, None)
  # store the cross_validator's test indices only to save space
- cross_validator_indices = [test for _, test in build_cross_validator.split(X, y)]
+ cross_validator_indices = [test for _, test in build_cross_validator.split(X, y, None)]
  local_indices_file_name = get_temp_file_path()
  with open(local_indices_file_name, mode="w+b") as local_indices_file_obj:
  cp.dump(cross_validator_indices, local_indices_file_obj)
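The lines added above use sklearn's public utilities: `check_cv` resolves whatever was passed as `cv` into a concrete splitter, `indexable` normalizes `X`/`y`, and `get_n_splits` yields the fold count that later reaches `construct_cv_results` as `n_split`. A minimal sketch with toy data (the names `X`, `y`, and `estimator` here are placeholders, not the trainer's variables):

```python
import numpy as np
from sklearn.base import is_classifier
from sklearn.model_selection import GridSearchCV, check_cv
from sklearn.svm import SVC
from sklearn.utils.validation import indexable

X = np.random.rand(20, 3)
y = np.array([0, 1] * 10)
estimator = GridSearchCV(SVC(), {"C": [0.1, 1.0]}, cv=5)

# Resolve the cv argument into a concrete cross-validator, as the trainer does.
build_cross_validator = check_cv(estimator.cv, y, classifier=is_classifier(estimator.estimator))
X, y, _ = indexable(X, y, None)
n_splits = build_cross_validator.get_n_splits(X, y, None)  # 5 in this toy setup

# Keep only each fold's test indices to save space before uploading to the stage.
cross_validator_indices = [test for _, test in build_cross_validator.split(X, y, None)]
print(n_splits, [len(fold) for fold in cross_validator_indices])
```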
@@ -529,14 +499,14 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
  )
  ),
  )
-
- multimetric, cv_results_, best_param_index, scorers = construct_cv_results(
+ # multimetric, cv_results_, best_param_index, scorers
+ multimetric, cv_results_ = construct_cv_results(
+ estimator,
+ n_splits,
+ list(param_grid),
  HP_raw_results.select("CV_RESULTS").sort(F.col("PARAM_CV_IND")).collect(),
  cross_validator_indices_length,
  parameter_grid_length,
- {
- "return_train_score": estimator.return_train_score,
- }, # TODO(xjiang): support more kwargs in here
  )

  estimator.cv_results_ = cv_results_
@@ -568,7 +538,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
  # With a non-custom callable, we can select the best score
  # based on the best index
  estimator.best_score_ = cv_results_[f"mean_test_{refit_metric}"][estimator.best_index_]
- estimator.best_params_ = cv_results_["params"][best_param_index]
+ estimator.best_params_ = cv_results_["params"][estimator.best_index_]

  if original_refit:
  estimator.best_estimator_ = clone(estimator.estimator).set_params(
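The last hunk swaps the locally computed `best_param_index` for sklearn's own `best_index_`. A small, self-contained check of the invariant this relies on (plain sklearn, single metric, not the trainer's code):

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)
search = GridSearchCV(SVC(), {"C": [0.1, 1.0, 10.0]}, cv=3).fit(X, y)

# best_params_ is the candidate at position best_index_ in cv_results_,
# and best_score_ is that candidate's mean test score.
assert search.best_params_ == search.cv_results_["params"][search.best_index_]
assert search.best_score_ == search.cv_results_["mean_test_score"][search.best_index_]
```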
snowflake/ml/modeling/_internal/{snowpark_handlers.py → snowpark_implementations/snowpark_handlers.py}
@@ -306,7 +306,7 @@ class SnowparkHandlers:
  input_cols: List[str],
  label_cols: List[str],
  sample_weight_col: Optional[str],
- statement_params: Dict[str, str],
+ score_statement_params: Dict[str, str],
  ) -> float:
  import inspect
  import os
@@ -317,13 +317,13 @@ class SnowparkHandlers:
  importlib.import_module(import_name)

  for query in sql_queries[:-1]:
- _ = session.sql(query).collect(statement_params=statement_params)
+ _ = session.sql(query).collect(statement_params=score_statement_params)
  sp_df = session.sql(sql_queries[-1])
- df: pd.DataFrame = sp_df.to_pandas(statement_params=statement_params)
+ df: pd.DataFrame = sp_df.to_pandas(statement_params=score_statement_params)
  df.columns = sp_df.columns

  local_score_file_name = get_temp_file_path()
- session.file.get(stage_score_file_name, local_score_file_name, statement_params=statement_params)
+ session.file.get(stage_score_file_name, local_score_file_name, statement_params=score_statement_params)

  local_score_file_name_path = os.path.join(local_score_file_name, os.listdir(local_score_file_name)[0])
  with open(local_score_file_name_path, mode="r+b") as local_score_file_obj:
@@ -348,7 +348,7 @@ class SnowparkHandlers:
  return result

  # Call score sproc
- statement_params = telemetry.get_function_usage_statement_params(
+ score_statement_params = telemetry.get_function_usage_statement_params(
  project=_PROJECT,
  subproject=self._subproject,
  function_name=telemetry.get_statement_params_full_func_name(
@@ -357,6 +357,8 @@ class SnowparkHandlers:
  api_calls=[Session.call],
  custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
  )
+
+ kwargs = telemetry.get_sproc_statement_params_kwargs(score_wrapper_sproc, score_statement_params)
  score: float = score_wrapper_sproc(
  session,
  queries,
@@ -364,7 +366,8 @@ class SnowparkHandlers:
  input_cols,
  label_cols,
  sample_weight_col,
- statement_params,
+ score_statement_params,
+ **kwargs,
  )

  cleanup_temp_files([local_score_file_name])
snowflake/ml/modeling/_internal/{xgboost_external_memory_trainer.py → snowpark_implementations/xgboost_external_memory_trainer.py}
@@ -23,7 +23,9 @@ from snowflake.ml.modeling._internal.model_specifications import (
  ModelSpecifications,
  ModelSpecificationsBuilder,
  )
- from snowflake.ml.modeling._internal.snowpark_trainer import SnowparkModelTrainer
+ from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_trainer import (
+ SnowparkModelTrainer,
+ )
  from snowflake.snowpark import (
  DataFrame,
  Session,
snowflake/ml/modeling/calibration/calibrated_classifier_cv.py
@@ -26,7 +26,7 @@ from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
  from snowflake.ml._internal.utils import pkg_version_utils, identifier
  from snowflake.snowpark import DataFrame, Session
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
- from snowflake.ml.modeling._internal.snowpark_handlers import SnowparkHandlers as HandlersImpl
+ from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import SnowparkHandlers as HandlersImpl
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
  from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
  from snowflake.ml.modeling._internal.estimator_utils import (
@@ -35,7 +35,7 @@ from snowflake.ml.modeling._internal.estimator_utils import (
  transform_snowml_obj_to_sklearn_obj,
  validate_sklearn_args,
  )
- from snowflake.ml.modeling._internal.estimator_protocols import FitPredictHandlers
+ from snowflake.ml.modeling._internal.estimator_protocols import TransformerHandlers

  from snowflake.ml.model.model_signature import (
  DataType,
@@ -229,7 +229,7 @@ class CalibratedClassifierCV(BaseTransformer):
  self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
  # If user used snowpark dataframe during fit, here it stores the snowpark input_cols, otherwise the processed input_cols
  self._snowpark_cols: Optional[List[str]] = self.input_cols
- self._handlers: FitPredictHandlers = HandlersImpl(class_name=CalibratedClassifierCV.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
+ self._handlers: TransformerHandlers = HandlersImpl(class_name=CalibratedClassifierCV.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
  self._autogenerated = True

  def _get_rand_id(self) -> str:
@@ -589,6 +589,22 @@ class CalibratedClassifierCV(BaseTransformer):
  # each row containing a list of values.
  expected_dtype = "ARRAY"

+ # If we were unable to assign a type to this transform in the factory, infer the type here.
+ if expected_dtype == "":
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+ expected_dtype = "ARRAY"
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+ expected_dtype = "ARRAY"
+ else:
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+ # We can only infer the output types from the input types if the following two statemetns are true:
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
+
  output_df = self._batch_inference(
  dataset=dataset,
  inference_method="transform",
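The sixteen added lines recur verbatim in every autogenerated transformer below; they only run when the factory left `expected_dtype` empty. Restated as a hypothetical standalone helper so the rule is easier to read (the function name and arguments are illustrative, not part of the package):

```python
from typing import Any, List, Optional


def infer_expected_dtype(sklearn_obj: Any, output_cols: List[str], input_types: List[str]) -> Optional[str]:
    """Fallback rule: a cluster/component count that differs from the number of output
    columns means the transform emits an ARRAY; otherwise a single uniform input type
    of matching width can be propagated to the output columns."""
    if hasattr(sklearn_obj, "n_clusters") and sklearn_obj.n_clusters != len(output_cols):
        return "ARRAY"
    if hasattr(sklearn_obj, "n_components") and sklearn_obj.n_components != len(output_cols):
        return "ARRAY"
    if input_types and all(t == input_types[0] for t in input_types) and len(input_types) == len(output_cols):
        return input_types[0]  # e.g. "FLOAT" propagated to every output column
    return None  # caller keeps its default handling


# Example: a KMeans-like object with 8 clusters but a single output column -> ARRAY.
class _FakeKMeans:
    n_clusters = 8


print(infer_expected_dtype(_FakeKMeans(), ["OUTPUT_0"], ["FLOAT", "FLOAT"]))  # ARRAY
```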
snowflake/ml/modeling/cluster/affinity_propagation.py
@@ -26,7 +26,7 @@ from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
  from snowflake.ml._internal.utils import pkg_version_utils, identifier
  from snowflake.snowpark import DataFrame, Session
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
- from snowflake.ml.modeling._internal.snowpark_handlers import SnowparkHandlers as HandlersImpl
+ from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import SnowparkHandlers as HandlersImpl
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
  from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
  from snowflake.ml.modeling._internal.estimator_utils import (
@@ -35,7 +35,7 @@ from snowflake.ml.modeling._internal.estimator_utils import (
  transform_snowml_obj_to_sklearn_obj,
  validate_sklearn_args,
  )
- from snowflake.ml.modeling._internal.estimator_protocols import FitPredictHandlers
+ from snowflake.ml.modeling._internal.estimator_protocols import TransformerHandlers

  from snowflake.ml.model.model_signature import (
  DataType,
@@ -204,7 +204,7 @@ class AffinityPropagation(BaseTransformer):
  self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
  # If user used snowpark dataframe during fit, here it stores the snowpark input_cols, otherwise the processed input_cols
  self._snowpark_cols: Optional[List[str]] = self.input_cols
- self._handlers: FitPredictHandlers = HandlersImpl(class_name=AffinityPropagation.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
+ self._handlers: TransformerHandlers = HandlersImpl(class_name=AffinityPropagation.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
  self._autogenerated = True

  def _get_rand_id(self) -> str:
@@ -564,6 +564,22 @@ class AffinityPropagation(BaseTransformer):
  # each row containing a list of values.
  expected_dtype = "ARRAY"

+ # If we were unable to assign a type to this transform in the factory, infer the type here.
+ if expected_dtype == "":
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+ expected_dtype = "ARRAY"
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+ expected_dtype = "ARRAY"
+ else:
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+ # We can only infer the output types from the input types if the following two statemetns are true:
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
+
  output_df = self._batch_inference(
  dataset=dataset,
  inference_method="transform",
snowflake/ml/modeling/cluster/agglomerative_clustering.py
@@ -26,7 +26,7 @@ from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
  from snowflake.ml._internal.utils import pkg_version_utils, identifier
  from snowflake.snowpark import DataFrame, Session
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
- from snowflake.ml.modeling._internal.snowpark_handlers import SnowparkHandlers as HandlersImpl
+ from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import SnowparkHandlers as HandlersImpl
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
  from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
  from snowflake.ml.modeling._internal.estimator_utils import (
@@ -35,7 +35,7 @@ from snowflake.ml.modeling._internal.estimator_utils import (
  transform_snowml_obj_to_sklearn_obj,
  validate_sklearn_args,
  )
- from snowflake.ml.modeling._internal.estimator_protocols import FitPredictHandlers
+ from snowflake.ml.modeling._internal.estimator_protocols import TransformerHandlers

  from snowflake.ml.model.model_signature import (
  DataType,
@@ -237,7 +237,7 @@ class AgglomerativeClustering(BaseTransformer):
  self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
  # If user used snowpark dataframe during fit, here it stores the snowpark input_cols, otherwise the processed input_cols
  self._snowpark_cols: Optional[List[str]] = self.input_cols
- self._handlers: FitPredictHandlers = HandlersImpl(class_name=AgglomerativeClustering.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
+ self._handlers: TransformerHandlers = HandlersImpl(class_name=AgglomerativeClustering.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
  self._autogenerated = True

  def _get_rand_id(self) -> str:
@@ -595,6 +595,22 @@ class AgglomerativeClustering(BaseTransformer):
  # each row containing a list of values.
  expected_dtype = "ARRAY"

+ # If we were unable to assign a type to this transform in the factory, infer the type here.
+ if expected_dtype == "":
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+ expected_dtype = "ARRAY"
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+ expected_dtype = "ARRAY"
+ else:
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+ # We can only infer the output types from the input types if the following two statemetns are true:
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
+
  output_df = self._batch_inference(
  dataset=dataset,
  inference_method="transform",
snowflake/ml/modeling/cluster/birch.py
@@ -26,7 +26,7 @@ from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
  from snowflake.ml._internal.utils import pkg_version_utils, identifier
  from snowflake.snowpark import DataFrame, Session
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
- from snowflake.ml.modeling._internal.snowpark_handlers import SnowparkHandlers as HandlersImpl
+ from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import SnowparkHandlers as HandlersImpl
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
  from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
  from snowflake.ml.modeling._internal.estimator_utils import (
@@ -35,7 +35,7 @@ from snowflake.ml.modeling._internal.estimator_utils import (
  transform_snowml_obj_to_sklearn_obj,
  validate_sklearn_args,
  )
- from snowflake.ml.modeling._internal.estimator_protocols import FitPredictHandlers
+ from snowflake.ml.modeling._internal.estimator_protocols import TransformerHandlers

  from snowflake.ml.model.model_signature import (
  DataType,
@@ -195,7 +195,7 @@ class Birch(BaseTransformer):
  self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
  # If user used snowpark dataframe during fit, here it stores the snowpark input_cols, otherwise the processed input_cols
  self._snowpark_cols: Optional[List[str]] = self.input_cols
- self._handlers: FitPredictHandlers = HandlersImpl(class_name=Birch.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
+ self._handlers: TransformerHandlers = HandlersImpl(class_name=Birch.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
  self._autogenerated = True

  def _get_rand_id(self) -> str:
@@ -557,6 +557,22 @@ class Birch(BaseTransformer):
  # each row containing a list of values.
  expected_dtype = "ARRAY"

+ # If we were unable to assign a type to this transform in the factory, infer the type here.
+ if expected_dtype == "":
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+ expected_dtype = "ARRAY"
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+ expected_dtype = "ARRAY"
+ else:
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+ # We can only infer the output types from the input types if the following two statemetns are true:
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
+
  output_df = self._batch_inference(
  dataset=dataset,
  inference_method="transform",
snowflake/ml/modeling/cluster/bisecting_k_means.py
@@ -26,7 +26,7 @@ from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
  from snowflake.ml._internal.utils import pkg_version_utils, identifier
  from snowflake.snowpark import DataFrame, Session
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
- from snowflake.ml.modeling._internal.snowpark_handlers import SnowparkHandlers as HandlersImpl
+ from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import SnowparkHandlers as HandlersImpl
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
  from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
  from snowflake.ml.modeling._internal.estimator_utils import (
@@ -35,7 +35,7 @@ from snowflake.ml.modeling._internal.estimator_utils import (
  transform_snowml_obj_to_sklearn_obj,
  validate_sklearn_args,
  )
- from snowflake.ml.modeling._internal.estimator_protocols import FitPredictHandlers
+ from snowflake.ml.modeling._internal.estimator_protocols import TransformerHandlers

  from snowflake.ml.model.model_signature import (
  DataType,
@@ -244,7 +244,7 @@ class BisectingKMeans(BaseTransformer):
  self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
  # If user used snowpark dataframe during fit, here it stores the snowpark input_cols, otherwise the processed input_cols
  self._snowpark_cols: Optional[List[str]] = self.input_cols
- self._handlers: FitPredictHandlers = HandlersImpl(class_name=BisectingKMeans.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
+ self._handlers: TransformerHandlers = HandlersImpl(class_name=BisectingKMeans.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
  self._autogenerated = True

  def _get_rand_id(self) -> str:
@@ -606,6 +606,22 @@ class BisectingKMeans(BaseTransformer):
  # each row containing a list of values.
  expected_dtype = "ARRAY"

+ # If we were unable to assign a type to this transform in the factory, infer the type here.
+ if expected_dtype == "":
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+ expected_dtype = "ARRAY"
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+ expected_dtype = "ARRAY"
+ else:
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+ # We can only infer the output types from the input types if the following two statemetns are true:
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
+
  output_df = self._batch_inference(
  dataset=dataset,
  inference_method="transform",
snowflake/ml/modeling/cluster/dbscan.py
@@ -26,7 +26,7 @@ from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
  from snowflake.ml._internal.utils import pkg_version_utils, identifier
  from snowflake.snowpark import DataFrame, Session
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
- from snowflake.ml.modeling._internal.snowpark_handlers import SnowparkHandlers as HandlersImpl
+ from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import SnowparkHandlers as HandlersImpl
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
  from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
  from snowflake.ml.modeling._internal.estimator_utils import (
@@ -35,7 +35,7 @@ from snowflake.ml.modeling._internal.estimator_utils import (
  transform_snowml_obj_to_sklearn_obj,
  validate_sklearn_args,
  )
- from snowflake.ml.modeling._internal.estimator_protocols import FitPredictHandlers
+ from snowflake.ml.modeling._internal.estimator_protocols import TransformerHandlers

  from snowflake.ml.model.model_signature import (
  DataType,
@@ -212,7 +212,7 @@ class DBSCAN(BaseTransformer):
  self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
  # If user used snowpark dataframe during fit, here it stores the snowpark input_cols, otherwise the processed input_cols
  self._snowpark_cols: Optional[List[str]] = self.input_cols
- self._handlers: FitPredictHandlers = HandlersImpl(class_name=DBSCAN.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
+ self._handlers: TransformerHandlers = HandlersImpl(class_name=DBSCAN.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
  self._autogenerated = True

  def _get_rand_id(self) -> str:
@@ -570,6 +570,22 @@ class DBSCAN(BaseTransformer):
  # each row containing a list of values.
  expected_dtype = "ARRAY"

+ # If we were unable to assign a type to this transform in the factory, infer the type here.
+ if expected_dtype == "":
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+ expected_dtype = "ARRAY"
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+ expected_dtype = "ARRAY"
+ else:
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+ # We can only infer the output types from the input types if the following two statemetns are true:
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
+
  output_df = self._batch_inference(
  dataset=dataset,
  inference_method="transform",