snowflake-ml-python 1.1.1__py3-none-any.whl → 1.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (224)
  1. snowflake/cortex/_complete.py +1 -1
  2. snowflake/cortex/_extract_answer.py +1 -1
  3. snowflake/cortex/_sentiment.py +1 -1
  4. snowflake/cortex/_summarize.py +1 -1
  5. snowflake/cortex/_translate.py +1 -1
  6. snowflake/ml/_internal/env_utils.py +68 -6
  7. snowflake/ml/_internal/file_utils.py +34 -4
  8. snowflake/ml/_internal/telemetry.py +79 -91
  9. snowflake/ml/_internal/utils/retryable_http.py +16 -4
  10. snowflake/ml/_internal/utils/spcs_attribution_utils.py +122 -0
  11. snowflake/ml/dataset/dataset.py +1 -1
  12. snowflake/ml/model/_api.py +21 -14
  13. snowflake/ml/model/_client/model/model_impl.py +176 -0
  14. snowflake/ml/model/_client/model/model_method_info.py +19 -0
  15. snowflake/ml/model/_client/model/model_version_impl.py +291 -0
  16. snowflake/ml/model/_client/ops/metadata_ops.py +107 -0
  17. snowflake/ml/model/_client/ops/model_ops.py +308 -0
  18. snowflake/ml/model/_client/sql/model.py +75 -0
  19. snowflake/ml/model/_client/sql/model_version.py +213 -0
  20. snowflake/ml/model/_client/sql/stage.py +40 -0
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -4
  22. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +24 -8
  23. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +23 -0
  24. snowflake/ml/model/_deploy_client/snowservice/deploy.py +14 -2
  25. snowflake/ml/model/_deploy_client/utils/constants.py +1 -0
  26. snowflake/ml/model/_deploy_client/warehouse/deploy.py +2 -2
  27. snowflake/ml/model/_model_composer/model_composer.py +31 -9
  28. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +25 -10
  29. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -2
  30. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
  31. snowflake/ml/model/_model_composer/model_method/model_method.py +34 -3
  32. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +1 -1
  33. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +3 -1
  34. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +10 -28
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +18 -16
  36. snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
  37. snowflake/ml/model/model_signature.py +108 -53
  38. snowflake/ml/model/type_hints.py +1 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +554 -0
  40. snowflake/ml/modeling/_internal/estimator_protocols.py +1 -60
  41. snowflake/ml/modeling/_internal/model_specifications.py +146 -0
  42. snowflake/ml/modeling/_internal/model_trainer.py +13 -0
  43. snowflake/ml/modeling/_internal/model_trainer_builder.py +78 -0
  44. snowflake/ml/modeling/_internal/pandas_trainer.py +54 -0
  45. snowflake/ml/modeling/_internal/snowpark_handlers.py +6 -760
  46. snowflake/ml/modeling/_internal/snowpark_trainer.py +331 -0
  47. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +96 -124
  48. snowflake/ml/modeling/cluster/affinity_propagation.py +94 -124
  49. snowflake/ml/modeling/cluster/agglomerative_clustering.py +94 -124
  50. snowflake/ml/modeling/cluster/birch.py +94 -124
  51. snowflake/ml/modeling/cluster/bisecting_k_means.py +94 -124
  52. snowflake/ml/modeling/cluster/dbscan.py +94 -124
  53. snowflake/ml/modeling/cluster/feature_agglomeration.py +94 -124
  54. snowflake/ml/modeling/cluster/k_means.py +93 -124
  55. snowflake/ml/modeling/cluster/mean_shift.py +94 -124
  56. snowflake/ml/modeling/cluster/mini_batch_k_means.py +93 -124
  57. snowflake/ml/modeling/cluster/optics.py +94 -124
  58. snowflake/ml/modeling/cluster/spectral_biclustering.py +94 -124
  59. snowflake/ml/modeling/cluster/spectral_clustering.py +94 -124
  60. snowflake/ml/modeling/cluster/spectral_coclustering.py +94 -124
  61. snowflake/ml/modeling/compose/column_transformer.py +94 -124
  62. snowflake/ml/modeling/compose/transformed_target_regressor.py +96 -124
  63. snowflake/ml/modeling/covariance/elliptic_envelope.py +94 -124
  64. snowflake/ml/modeling/covariance/empirical_covariance.py +80 -110
  65. snowflake/ml/modeling/covariance/graphical_lasso.py +94 -124
  66. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +94 -124
  67. snowflake/ml/modeling/covariance/ledoit_wolf.py +85 -115
  68. snowflake/ml/modeling/covariance/min_cov_det.py +94 -124
  69. snowflake/ml/modeling/covariance/oas.py +80 -110
  70. snowflake/ml/modeling/covariance/shrunk_covariance.py +84 -114
  71. snowflake/ml/modeling/decomposition/dictionary_learning.py +94 -124
  72. snowflake/ml/modeling/decomposition/factor_analysis.py +94 -124
  73. snowflake/ml/modeling/decomposition/fast_ica.py +94 -124
  74. snowflake/ml/modeling/decomposition/incremental_pca.py +94 -124
  75. snowflake/ml/modeling/decomposition/kernel_pca.py +94 -124
  76. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +94 -124
  77. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +94 -124
  78. snowflake/ml/modeling/decomposition/pca.py +94 -124
  79. snowflake/ml/modeling/decomposition/sparse_pca.py +94 -124
  80. snowflake/ml/modeling/decomposition/truncated_svd.py +94 -124
  81. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +96 -124
  82. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +91 -119
  83. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +96 -124
  84. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +96 -124
  85. snowflake/ml/modeling/ensemble/bagging_classifier.py +96 -124
  86. snowflake/ml/modeling/ensemble/bagging_regressor.py +96 -124
  87. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +96 -124
  88. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +96 -124
  89. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +96 -124
  90. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +96 -124
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +96 -124
  92. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +96 -124
  93. snowflake/ml/modeling/ensemble/isolation_forest.py +94 -124
  94. snowflake/ml/modeling/ensemble/random_forest_classifier.py +96 -124
  95. snowflake/ml/modeling/ensemble/random_forest_regressor.py +96 -124
  96. snowflake/ml/modeling/ensemble/stacking_regressor.py +96 -124
  97. snowflake/ml/modeling/ensemble/voting_classifier.py +96 -124
  98. snowflake/ml/modeling/ensemble/voting_regressor.py +91 -119
  99. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +82 -110
  100. snowflake/ml/modeling/feature_selection/select_fdr.py +80 -108
  101. snowflake/ml/modeling/feature_selection/select_fpr.py +80 -108
  102. snowflake/ml/modeling/feature_selection/select_fwe.py +80 -108
  103. snowflake/ml/modeling/feature_selection/select_k_best.py +81 -109
  104. snowflake/ml/modeling/feature_selection/select_percentile.py +80 -108
  105. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +94 -124
  106. snowflake/ml/modeling/feature_selection/variance_threshold.py +76 -106
  107. snowflake/ml/modeling/framework/base.py +2 -2
  108. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +96 -124
  109. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +96 -124
  110. snowflake/ml/modeling/impute/iterative_imputer.py +94 -124
  111. snowflake/ml/modeling/impute/knn_imputer.py +94 -124
  112. snowflake/ml/modeling/impute/missing_indicator.py +94 -124
  113. snowflake/ml/modeling/impute/simple_imputer.py +1 -1
  114. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +77 -107
  115. snowflake/ml/modeling/kernel_approximation/nystroem.py +94 -124
  116. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +94 -124
  117. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +86 -116
  118. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +84 -114
  119. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +96 -124
  120. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +71 -100
  121. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +71 -100
  122. snowflake/ml/modeling/linear_model/ard_regression.py +96 -124
  123. snowflake/ml/modeling/linear_model/bayesian_ridge.py +96 -124
  124. snowflake/ml/modeling/linear_model/elastic_net.py +96 -124
  125. snowflake/ml/modeling/linear_model/elastic_net_cv.py +96 -124
  126. snowflake/ml/modeling/linear_model/gamma_regressor.py +96 -124
  127. snowflake/ml/modeling/linear_model/huber_regressor.py +96 -124
  128. snowflake/ml/modeling/linear_model/lars.py +96 -124
  129. snowflake/ml/modeling/linear_model/lars_cv.py +96 -124
  130. snowflake/ml/modeling/linear_model/lasso.py +96 -124
  131. snowflake/ml/modeling/linear_model/lasso_cv.py +96 -124
  132. snowflake/ml/modeling/linear_model/lasso_lars.py +96 -124
  133. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +96 -124
  134. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +96 -124
  135. snowflake/ml/modeling/linear_model/linear_regression.py +91 -119
  136. snowflake/ml/modeling/linear_model/logistic_regression.py +96 -124
  137. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +96 -124
  138. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +96 -124
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +96 -124
  140. snowflake/ml/modeling/linear_model/multi_task_lasso.py +96 -124
  141. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +96 -124
  142. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +96 -124
  143. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +96 -124
  144. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +95 -124
  145. snowflake/ml/modeling/linear_model/perceptron.py +95 -124
  146. snowflake/ml/modeling/linear_model/poisson_regressor.py +96 -124
  147. snowflake/ml/modeling/linear_model/ransac_regressor.py +96 -124
  148. snowflake/ml/modeling/linear_model/ridge.py +96 -124
  149. snowflake/ml/modeling/linear_model/ridge_classifier.py +96 -124
  150. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +96 -124
  151. snowflake/ml/modeling/linear_model/ridge_cv.py +96 -124
  152. snowflake/ml/modeling/linear_model/sgd_classifier.py +96 -124
  153. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +94 -124
  154. snowflake/ml/modeling/linear_model/sgd_regressor.py +96 -124
  155. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +96 -124
  156. snowflake/ml/modeling/linear_model/tweedie_regressor.py +96 -124
  157. snowflake/ml/modeling/manifold/isomap.py +94 -124
  158. snowflake/ml/modeling/manifold/mds.py +94 -124
  159. snowflake/ml/modeling/manifold/spectral_embedding.py +94 -124
  160. snowflake/ml/modeling/manifold/tsne.py +94 -124
  161. snowflake/ml/modeling/metrics/classification.py +187 -52
  162. snowflake/ml/modeling/metrics/correlation.py +4 -2
  163. snowflake/ml/modeling/metrics/covariance.py +7 -4
  164. snowflake/ml/modeling/metrics/ranking.py +32 -16
  165. snowflake/ml/modeling/metrics/regression.py +60 -32
  166. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +94 -124
  167. snowflake/ml/modeling/mixture/gaussian_mixture.py +94 -124
  168. snowflake/ml/modeling/model_selection/grid_search_cv.py +88 -138
  169. snowflake/ml/modeling/model_selection/randomized_search_cv.py +90 -144
  170. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +86 -114
  171. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +93 -121
  172. snowflake/ml/modeling/multiclass/output_code_classifier.py +94 -122
  173. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +92 -120
  174. snowflake/ml/modeling/naive_bayes/categorical_nb.py +96 -124
  175. snowflake/ml/modeling/naive_bayes/complement_nb.py +92 -120
  176. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -107
  177. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +88 -116
  178. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +96 -124
  179. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +96 -124
  180. snowflake/ml/modeling/neighbors/kernel_density.py +94 -124
  181. snowflake/ml/modeling/neighbors/local_outlier_factor.py +94 -124
  182. snowflake/ml/modeling/neighbors/nearest_centroid.py +89 -117
  183. snowflake/ml/modeling/neighbors/nearest_neighbors.py +94 -124
  184. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +96 -124
  185. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +96 -124
  186. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +96 -124
  187. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +94 -124
  188. snowflake/ml/modeling/neural_network/mlp_classifier.py +96 -124
  189. snowflake/ml/modeling/neural_network/mlp_regressor.py +96 -124
  190. snowflake/ml/modeling/parameters/disable_distributed_hpo.py +2 -6
  191. snowflake/ml/modeling/preprocessing/binarizer.py +14 -9
  192. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +0 -4
  193. snowflake/ml/modeling/preprocessing/label_encoder.py +21 -13
  194. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +20 -14
  195. snowflake/ml/modeling/preprocessing/min_max_scaler.py +35 -19
  196. snowflake/ml/modeling/preprocessing/normalizer.py +6 -9
  197. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +20 -13
  198. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +25 -13
  199. snowflake/ml/modeling/preprocessing/polynomial_features.py +94 -124
  200. snowflake/ml/modeling/preprocessing/robust_scaler.py +28 -14
  201. snowflake/ml/modeling/preprocessing/standard_scaler.py +25 -13
  202. snowflake/ml/modeling/semi_supervised/label_propagation.py +96 -124
  203. snowflake/ml/modeling/semi_supervised/label_spreading.py +96 -124
  204. snowflake/ml/modeling/svm/linear_svc.py +96 -124
  205. snowflake/ml/modeling/svm/linear_svr.py +96 -124
  206. snowflake/ml/modeling/svm/nu_svc.py +96 -124
  207. snowflake/ml/modeling/svm/nu_svr.py +96 -124
  208. snowflake/ml/modeling/svm/svc.py +96 -124
  209. snowflake/ml/modeling/svm/svr.py +96 -124
  210. snowflake/ml/modeling/tree/decision_tree_classifier.py +96 -124
  211. snowflake/ml/modeling/tree/decision_tree_regressor.py +96 -124
  212. snowflake/ml/modeling/tree/extra_tree_classifier.py +96 -124
  213. snowflake/ml/modeling/tree/extra_tree_regressor.py +96 -124
  214. snowflake/ml/modeling/xgboost/xgb_classifier.py +96 -125
  215. snowflake/ml/modeling/xgboost/xgb_regressor.py +96 -125
  216. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +96 -125
  217. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +96 -125
  218. snowflake/ml/registry/model_registry.py +2 -0
  219. snowflake/ml/registry/registry.py +215 -0
  220. snowflake/ml/version.py +1 -1
  221. {snowflake_ml_python-1.1.1.dist-info → snowflake_ml_python-1.1.2.dist-info}/METADATA +21 -3
  222. snowflake_ml_python-1.1.2.dist-info/RECORD +347 -0
  223. snowflake_ml_python-1.1.1.dist-info/RECORD +0 -331
  224. {snowflake_ml_python-1.1.1.dist-info → snowflake_ml_python-1.1.2.dist-info}/WHEEL +0 -0
snowflake/ml/modeling/_internal/snowpark_trainer.py (new file)
@@ -0,0 +1,331 @@
+ import importlib
+ import inspect
+ import os
+ import posixpath
+ from typing import Any, Callable, Dict, List, Optional, Tuple
+
+ import cloudpickle as cp
+
+ from snowflake.ml._internal import telemetry
+ from snowflake.ml._internal.exceptions import (
+     error_codes,
+     exceptions,
+     modeling_error_messages,
+ )
+ from snowflake.ml._internal.utils import identifier, snowpark_dataframe_utils
+ from snowflake.ml._internal.utils.query_result_checker import SqlResultValidator
+ from snowflake.ml._internal.utils.temp_file_utils import (
+     cleanup_temp_files,
+     get_temp_file_path,
+ )
+ from snowflake.ml.modeling._internal.model_specifications import (
+     ModelSpecifications,
+     ModelSpecificationsBuilder,
+ )
+ from snowflake.snowpark import DataFrame, Session, exceptions as snowpark_exceptions
+ from snowflake.snowpark._internal.utils import (
+     TempObjectType,
+     random_name_for_temp_object,
+ )
+ from snowflake.snowpark.functions import sproc
+ from snowflake.snowpark.stored_procedure import StoredProcedure
+
+ cp.register_pickle_by_value(inspect.getmodule(get_temp_file_path))
+ cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
+
+ _PROJECT = "ModelDevelopment"
+
+
+ class SnowparkModelTrainer:
+     """
+     A class for training models on Snowflake data using a stored procedure (sproc).
+
+     TODO (snandamuri): Introduce the concept of an executor that would take the training function
+     and execute it on target environments such as local, Snowflake warehouse, or SPCS.
+     """
+
+     def __init__(
+         self,
+         estimator: object,
+         dataset: DataFrame,
+         session: Session,
+         input_cols: List[str],
+         label_cols: Optional[List[str]],
+         sample_weight_col: Optional[str],
+         autogenerated: bool = False,
+         subproject: str = "",
+     ) -> None:
+         """
+         Initializes the SnowparkModelTrainer with a model, a Snowpark DataFrame, and feature and label column names.
+
+         Args:
+             estimator: SKLearn compatible estimator or transformer object.
+             dataset: The dataset used for training the model.
+             session: Snowflake session object to be used for training.
+             input_cols: The name(s) of one or more columns in a DataFrame containing a feature to be used for training.
+             label_cols: The name(s) of one or more columns in a DataFrame representing the target variable(s) to learn.
+             sample_weight_col: The column name representing the weight of training examples.
+             autogenerated: A boolean denoting whether the trainer is being used by autogenerated code.
+             subproject: Subproject name to be used in telemetry.
+         """
+         self.estimator = estimator
+         self.dataset = dataset
+         self.session = session
+         self.input_cols = input_cols
+         self.label_cols = label_cols
+         self.sample_weight_col = sample_weight_col
+         self._autogenerated = autogenerated
+         self._subproject = subproject
+         self._class_name = estimator.__class__.__name__
+
+     def _create_temp_stage(self) -> str:
+         """
+         Creates a temporary stage.
+
+         Returns:
+             Temp stage name.
+         """
+         # Create temp stage to upload pickled model file.
+         transform_stage_name = random_name_for_temp_object(TempObjectType.STAGE)
+         stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
+         SqlResultValidator(session=self.session, query=stage_creation_query).has_dimensions(
+             expected_rows=1, expected_cols=1
+         ).validate()
+         return transform_stage_name
+
+     def _upload_model_to_stage(self, stage_name: str) -> Tuple[str, str]:
+         """
+         Util method to pickle and upload the model to a temp Snowflake stage.
+
+         Args:
+             stage_name: Stage name to save model.
+
+         Returns:
+             A tuple containing the stage file path for the pickled input model and the location to store trained
+             models (the response from the training sproc).
+         """
+         # Create a temp file and dump the transform to that file.
+         local_transform_file_name = get_temp_file_path()
+         with open(local_transform_file_name, mode="w+b") as local_transform_file:
+             cp.dump(self.estimator, local_transform_file)
+
+         # Use posixpath to construct stage paths
+         stage_transform_file_name = posixpath.join(stage_name, os.path.basename(local_transform_file_name))
+         stage_result_file_name = posixpath.join(stage_name, os.path.basename(local_transform_file_name))
+
+         statement_params = telemetry.get_function_usage_statement_params(
+             project=_PROJECT,
+             subproject=self._subproject,
+             function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
+             api_calls=[sproc],
+             custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
+         )
+         # Put locally serialized transform on stage.
+         self.session.file.put(
+             local_transform_file_name,
+             stage_transform_file_name,
+             auto_compress=False,
+             overwrite=True,
+             statement_params=statement_params,
+         )
+
+         cleanup_temp_files([local_transform_file_name])
+         return (stage_transform_file_name, stage_result_file_name)
+
+     def _fetch_model_from_stage(self, dir_path: str, file_name: str, statement_params: Dict[str, str]) -> object:
+         """
+         Downloads the serialized model from a stage location and unpickles it.
+
+         Args:
+             dir_path: Stage directory path where results are stored.
+             file_name: File name within the directory where results are stored.
+             statement_params: Statement params to be attached to the SQL queries issued from this method.
+
+         Returns:
+             Deserialized model object.
+         """
+         local_result_file_name = get_temp_file_path()
+         self.session.file.get(
+             posixpath.join(dir_path, file_name),
+             local_result_file_name,
+             statement_params=statement_params,
+         )
+
+         with open(os.path.join(local_result_file_name, file_name), mode="r+b") as result_file_obj:
+             fit_estimator = cp.load(result_file_obj)
+
+         cleanup_temp_files([local_result_file_name])
+         return fit_estimator
+
+     def _build_fit_wrapper_sproc(
+         self,
+         model_spec: ModelSpecifications,
+     ) -> Callable[[Any, List[str], str, str, List[str], List[str], Optional[str], Dict[str, str]], str]:
+         """
+         Constructs and returns a Python stored procedure function to be used for training the model.
+
+         Args:
+             model_spec: ModelSpecifications object that contains model-specific information
+                 such as required imports and package dependencies.
+
+         Returns:
+             A callable that can be registered as a stored procedure.
+         """
+         imports = model_spec.imports  # In order for the sproc to not resolve this reference in snowflake.ml
+
+         def fit_wrapper_function(
+             session: Session,
+             sql_queries: List[str],
+             stage_transform_file_name: str,
+             stage_result_file_name: str,
+             input_cols: List[str],
+             label_cols: List[str],
+             sample_weight_col: Optional[str],
+             statement_params: Dict[str, str],
+         ) -> str:
+             import inspect
+             import os
+
+             import cloudpickle as cp
+             import pandas as pd
+
+             for import_name in imports:
+                 importlib.import_module(import_name)
+
+             # Execute snowpark queries and obtain the results as pandas dataframe
+             # NB: this implies that the result data must fit into memory.
+             for query in sql_queries[:-1]:
+                 _ = session.sql(query).collect(statement_params=statement_params)
+             sp_df = session.sql(sql_queries[-1])
+             df: pd.DataFrame = sp_df.to_pandas(statement_params=statement_params)
+             df.columns = sp_df.columns
+
+             local_transform_file_name = get_temp_file_path()
+
+             session.file.get(stage_transform_file_name, local_transform_file_name, statement_params=statement_params)
+
+             local_transform_file_path = os.path.join(
+                 local_transform_file_name, os.listdir(local_transform_file_name)[0]
+             )
+             with open(local_transform_file_path, mode="r+b") as local_transform_file_obj:
+                 estimator = cp.load(local_transform_file_obj)
+
+             argspec = inspect.getfullargspec(estimator.fit)
+             args = {"X": df[input_cols]}
+             if label_cols:
+                 label_arg_name = "Y" if "Y" in argspec.args else "y"
+                 args[label_arg_name] = df[label_cols].squeeze()
+
+             if sample_weight_col is not None and "sample_weight" in argspec.args:
+                 args["sample_weight"] = df[sample_weight_col].squeeze()
+
+             estimator.fit(**args)
+
+             local_result_file_name = get_temp_file_path()
+
+             with open(local_result_file_name, mode="w+b") as local_result_file_obj:
+                 cp.dump(estimator, local_result_file_obj)
+
+             session.file.put(
+                 local_result_file_name,
+                 stage_result_file_name,
+                 auto_compress=False,
+                 overwrite=True,
+                 statement_params=statement_params,
+             )
+
+             # Note: you can add something like + "|" + str(df) to the return string
+             # to pass debug information to the caller.
+             return str(os.path.basename(local_result_file_name))
+
+         return fit_wrapper_function
+
+     def _get_fit_wrapper_sproc(self, statement_params: Dict[str, str]) -> StoredProcedure:
+         # If the sproc already exists, don't register.
+         if not hasattr(self.session, "_FIT_WRAPPER_SPROCS"):
+             self.session._FIT_WRAPPER_SPROCS: Dict[str, StoredProcedure] = {}  # type: ignore[attr-defined, misc]
+
+         model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
+         fit_sproc_key = model_spec.__class__.__name__
+         if fit_sproc_key in self.session._FIT_WRAPPER_SPROCS:  # type: ignore[attr-defined]
+             fit_sproc: StoredProcedure = self.session._FIT_WRAPPER_SPROCS[fit_sproc_key]  # type: ignore[attr-defined]
+             return fit_sproc
+
+         fit_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
+
+         fit_wrapper_sproc = self.session.sproc.register(
+             func=self._build_fit_wrapper_sproc(model_spec=model_spec),
+             is_permanent=False,
+             name=fit_sproc_name,
+             packages=["snowflake-snowpark-python"] + model_spec.pkgDependencies,  # type: ignore[arg-type]
+             replace=True,
+             session=self.session,
+             statement_params=statement_params,
+         )
+
+         self.session._FIT_WRAPPER_SPROCS[fit_sproc_key] = fit_wrapper_sproc  # type: ignore[attr-defined]
+
+         return fit_wrapper_sproc
+
+     def train(self) -> object:
+         """
+         Trains the model by pushing down the compute into Snowflake using stored procedures.
+
+         Returns:
+             Trained model.
+
+         Raises:
+             e: Raised if any Snowflake operation fails for any reason.
+             SnowflakeMLException: Known exceptions are caught and re-raised with a more detailed error message.
+         """
+         dataset = snowpark_dataframe_utils.cast_snowpark_dataframe_column_types(self.dataset)
+
+         # TODO(snandamuri): Handle the already-in-a-stored-procedure case in the builder.
+
+         # Extract the queries that generated the dataframe. We will need to pass them to the fit procedure.
+         queries = dataset.queries["queries"]
+
+         transform_stage_name = self._create_temp_stage()
+         (stage_transform_file_name, stage_result_file_name) = self._upload_model_to_stage(
+             stage_name=transform_stage_name
+         )
+
+         # Call fit sproc
+         statement_params = telemetry.get_function_usage_statement_params(
+             project=_PROJECT,
+             subproject=self._subproject,
+             function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
+             api_calls=[Session.call],
+             custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
+         )
+
+         fit_wrapper_sproc = self._get_fit_wrapper_sproc(statement_params=statement_params)
+
+         try:
+             sproc_export_file_name: str = fit_wrapper_sproc(
+                 self.session,
+                 queries,
+                 stage_transform_file_name,
+                 stage_result_file_name,
+                 self.input_cols,
+                 self.label_cols,
+                 self.sample_weight_col,
+                 statement_params,
+             )
+         except snowpark_exceptions.SnowparkClientException as e:
+             if "fit() missing 1 required positional argument: 'y'" in str(e):
+                 raise exceptions.SnowflakeMLException(
+                     error_code=error_codes.NOT_FOUND,
+                     original_exception=RuntimeError(modeling_error_messages.ATTRIBUTE_NOT_SET.format("label_cols")),
+                 ) from e
+             raise e
+
+         if "|" in sproc_export_file_name:
+             fields = sproc_export_file_name.strip().split("|")
+             sproc_export_file_name = fields[0]
+
+         return self._fetch_model_from_stage(
+             dir_path=stage_result_file_name,
+             file_name=sproc_export_file_name,
+             statement_params=statement_params,
+         )
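
For orientation, here is a minimal sketch of how the new SnowparkModelTrainer could be driven directly, assuming a valid Snowflake account. The connection parameters, table name, and column names below are placeholders, and in normal use the modeling estimators construct this trainer for you through ModelTrainerBuilder rather than you calling it yourself.

# Illustrative sketch only -- not part of the package diff. Assumes valid
# Snowflake credentials and a training table with the columns named below.
from sklearn.linear_model import LogisticRegression
from snowflake.snowpark import Session

from snowflake.ml.modeling._internal.snowpark_trainer import SnowparkModelTrainer

connection_parameters = {"account": "...", "user": "...", "password": "..."}  # placeholder credentials
session = Session.builder.configs(connection_parameters).create()

trainer = SnowparkModelTrainer(
    estimator=LogisticRegression(),          # any scikit-learn compatible estimator
    dataset=session.table("TRAINING_DATA"),  # hypothetical training table
    session=session,
    input_cols=["FEATURE_1", "FEATURE_2"],   # illustrative column names
    label_cols=["LABEL"],
    sample_weight_col=None,
)
fitted_estimator = trainer.train()  # fit() executes inside a temporary stored procedure
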
snowflake/ml/modeling/calibration/calibrated_classifier_cv.py
@@ -22,17 +22,19 @@ from sklearn.utils.metaestimators import available_if
  from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
  from snowflake.ml._internal import telemetry
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
+ from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
  from snowflake.ml._internal.utils import pkg_version_utils, identifier
- from snowflake.snowpark import DataFrame
+ from snowflake.snowpark import DataFrame, Session
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
  from snowflake.ml.modeling._internal.snowpark_handlers import SnowparkHandlers as HandlersImpl
+ from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
+ from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
  from snowflake.ml.modeling._internal.estimator_utils import (
      gather_dependencies,
      original_estimator_has_callable,
      transform_snowml_obj_to_sklearn_obj,
      validate_sklearn_args,
  )
- from snowflake.ml.modeling._internal.snowpark_handlers import SklearnWrapperProvider
  from snowflake.ml.modeling._internal.estimator_protocols import FitPredictHandlers

  from snowflake.ml.model.model_signature import (
@@ -52,7 +54,6 @@ _PROJECT = "ModelDevelopment"
  _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.calibration".replace("sklearn.", "").split("_")])


-
  class CalibratedClassifierCV(BaseTransformer):
      r"""Probability calibration with isotonic regression or logistic regression
      For more details on this class, see [sklearn.calibration.CalibratedClassifierCV]
@@ -60,6 +61,51 @@ class CalibratedClassifierCV(BaseTransformer):

      Parameters
      ----------
+
+     input_cols: Optional[Union[str, List[str]]]
+         A string or list of strings representing column names that contain features.
+         If this parameter is not specified, all columns in the input DataFrame except
+         the columns specified by label_cols, sample_weight_col, and passthrough_cols
+         parameters are considered input columns. Input columns can also be set after
+         initialization with the `set_input_cols` method.
+
+     label_cols: Optional[Union[str, List[str]]]
+         A string or list of strings representing column names that contain labels.
+         Label columns must be specified with this parameter during initialization
+         or with the `set_label_cols` method before fitting.
+
+     output_cols: Optional[Union[str, List[str]]]
+         A string or list of strings representing column names that will store the
+         output of predict and transform operations. The length of output_cols must
+         match the expected number of output columns from the specific predictor or
+         transformer class used.
+         If you omit this parameter, output column names are derived by adding an
+         OUTPUT_ prefix to the label column names for supervised estimators, or
+         OUTPUT_<IDX> for unsupervised estimators. These inferred output column names
+         work for predictors, but output_cols must be set explicitly for transformers.
+         In general, explicitly specifying output column names is clearer, especially
+         if you don’t specify the input column names.
+         To transform in place, pass the same names for input_cols and output_cols.
+         Output columns can also be set after
+         initialization with the `set_output_cols` method.
+
+     sample_weight_col: Optional[str]
+         A string representing the column name containing the sample weights.
+         This argument is only required when working with weighted datasets. Sample
+         weight column can also be set after initialization with the
+         `set_sample_weight_col` method.
+
+     passthrough_cols: Optional[Union[str, List[str]]]
+         A string or a list of strings indicating column names to be excluded from any
+         operations (such as train, transform, or inference). These specified column(s)
+         will remain untouched throughout the process. This option is helpful in scenarios
+         requiring automatic input_cols inference, but need to avoid using specific
+         columns, like index columns, during training or inference. Passthrough columns
+         can also be set after initialization with the `set_passthrough_cols` method.
+
+     drop_input_cols: Optional[bool], default=False
+         If set, the response of predict(), transform() methods will not contain input columns.
+
      estimator: estimator instance, default=None
          The classifier whose output need to be calibrated to provide more
          accurate `predict_proba` outputs. The default classifier is
@@ -121,42 +167,6 @@ class CalibratedClassifierCV(BaseTransformer):

      base_estimator: estimator instance
          This parameter is deprecated. Use `estimator` instead.
-
-     input_cols: Optional[Union[str, List[str]]]
-         A string or list of strings representing column names that contain features.
-         If this parameter is not specified, all columns in the input DataFrame except
-         the columns specified by label_cols, sample_weight_col, and passthrough_cols
-         parameters are considered input columns.
-
-     label_cols: Optional[Union[str, List[str]]]
-         A string or list of strings representing column names that contain labels.
-         This is a required param for estimators, as there is no way to infer these
-         columns. If this parameter is not specified, then object is fitted without
-         labels (like a transformer).
-
-     output_cols: Optional[Union[str, List[str]]]
-         A string or list of strings representing column names that will store the
-         output of predict and transform operations. The length of output_cols must
-         match the expected number of output columns from the specific estimator or
-         transformer class used.
-         If this parameter is not specified, output column names are derived by
-         adding an OUTPUT_ prefix to the label column names. These inferred output
-         column names work for estimator's predict() method, but output_cols must
-         be set explicitly for transformers.
-
-     sample_weight_col: Optional[str]
-         A string representing the column name containing the sample weights.
-         This argument is only required when working with weighted datasets.
-
-     passthrough_cols: Optional[Union[str, List[str]]]
-         A string or a list of strings indicating column names to be excluded from any
-         operations (such as train, transform, or inference). These specified column(s)
-         will remain untouched throughout the process. This option is helpful in scenarios
-         requiring automatic input_cols inference, but need to avoid using specific
-         columns, like index columns, during training or inference.
-
-     drop_input_cols: Optional[bool], default=False
-         If set, the response of predict(), transform() methods will not contain input columns.
      """

      def __init__( # type: ignore[no-untyped-def]
@@ -183,7 +193,7 @@ class CalibratedClassifierCV(BaseTransformer):
          self.set_passthrough_cols(passthrough_cols)
          self.set_drop_input_cols(drop_input_cols)
          self.set_sample_weight_col(sample_weight_col)
-         deps = set(SklearnWrapperProvider().dependencies)
+         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
          deps = deps | gather_dependencies(estimator)
          deps = deps | gather_dependencies(base_estimator)
          self._deps = list(deps)
@@ -199,13 +209,14 @@ class CalibratedClassifierCV(BaseTransformer):
              args=init_args,
              klass=sklearn.calibration.CalibratedClassifierCV
          )
-         self._sklearn_object = sklearn.calibration.CalibratedClassifierCV(
+         self._sklearn_object: Any = sklearn.calibration.CalibratedClassifierCV(
              **cleaned_up_init_args,
          )
          self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
          # If user used snowpark dataframe during fit, here it stores the snowpark input_cols, otherwise the processed input_cols
          self._snowpark_cols: Optional[List[str]] = self.input_cols
-         self._handlers: FitPredictHandlers = HandlersImpl(class_name=CalibratedClassifierCV.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True, wrapper_provider=SklearnWrapperProvider())
+         self._handlers: FitPredictHandlers = HandlersImpl(class_name=CalibratedClassifierCV.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
+         self._autogenerated = True

      def _get_rand_id(self) -> str:
          """
@@ -261,54 +272,48 @@ class CalibratedClassifierCV(BaseTransformer):
              self
          """
          self._infer_input_output_cols(dataset)
-         if isinstance(dataset, pd.DataFrame):
-             assert self._sklearn_object is not None # keep mypy happy
-             self._sklearn_object = self._handlers.fit_pandas(
-                 dataset,
-                 self._sklearn_object,
-                 self.input_cols,
-                 self.label_cols,
-                 self.sample_weight_col
-             )
-         elif isinstance(dataset, DataFrame):
-             self._fit_snowpark(dataset)
-         else:
-             raise TypeError(
-                 f"Unexpected dataset type: {type(dataset)}."
-                 "Supported dataset types: snowpark.DataFrame, pandas.DataFrame."
-             )
+         if isinstance(dataset, DataFrame):
+             session = dataset._session
+             assert session is not None # keep mypy happy
+             # Validate that key package version in user workspace are supported in snowflake conda channel
+             # If customer doesn't have package in conda channel, replace the ones have the closest versions
+             self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
+                 pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
+             # Specify input columns so column pruning will be enforced
+             selected_cols = self._get_active_columns()
+             if len(selected_cols) > 0:
+                 dataset = dataset.select(selected_cols)
+
+             self._snowpark_cols = dataset.select(self.input_cols).columns
+
+             # If we are already in a stored procedure, no need to kick off another one.
+             if SNOWML_SPROC_ENV in os.environ:
+                 statement_params = telemetry.get_function_usage_statement_params(
+                     project=_PROJECT,
+                     subproject=_SUBPROJECT,
+                     function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), CalibratedClassifierCV.__class__.__name__),
+                     api_calls=[Session.call],
+                     custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
+                 )
+                 pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
+                 pd_df.columns = dataset.columns
+                 dataset = pd_df
+
+         model_trainer = ModelTrainerBuilder.build(
+             estimator=self._sklearn_object,
+             dataset=dataset,
+             input_cols=self.input_cols,
+             label_cols=self.label_cols,
+             sample_weight_col=self.sample_weight_col,
+             autogenerated=self._autogenerated,
+             subproject=_SUBPROJECT
+         )
+         self._sklearn_object = model_trainer.train()
          self._is_fitted = True
          self._get_model_signatures(dataset)
          return self

-     def _fit_snowpark(self, dataset: DataFrame) -> None:
-         session = dataset._session
-         assert session is not None # keep mypy happy
-         # Validate that key package version in user workspace are supported in snowflake conda channel
-         # If customer doesn't have package in conda channel, replace the ones have the closest versions
-         self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-             pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
-         # Specify input columns so column pruning will be enforced
-         selected_cols = self._get_active_columns()
-         if len(selected_cols) > 0:
-             dataset = dataset.select(selected_cols)
-
-         estimator = self._sklearn_object
-         assert estimator is not None # Keep mypy happy
-
-         self._snowpark_cols = dataset.select(self.input_cols).columns
-
-         self._sklearn_object = self._handlers.fit_snowpark(
-             dataset,
-             session,
-             estimator,
-             ["snowflake-snowpark-python"] + self._get_dependencies(),
-             self.input_cols,
-             self.label_cols,
-             self.sample_weight_col,
-         )
-
      def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
          if self._drop_input_cols:
              return []
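
The hunk above replaces the class-specific _fit_snowpark path with a call to ModelTrainerBuilder, a new module in this release (see model_trainer_builder.py, pandas_trainer.py, and snowpark_trainer.py in the file list). The builder's implementation is not shown in this diff, so the following is only a hypothetical sketch of the kind of dispatch it presumably performs; the keyword arguments are taken from the hunk above, everything else is an assumption.

# Hypothetical sketch -- not the package's actual ModelTrainerBuilder.
from typing import List, Optional, Union

import pandas as pd

from snowflake.snowpark import DataFrame
from snowflake.ml.modeling._internal.snowpark_trainer import SnowparkModelTrainer


def build_trainer_sketch(
    estimator: object,
    dataset: Union[DataFrame, pd.DataFrame],
    input_cols: List[str],
    label_cols: Optional[List[str]],
    sample_weight_col: Optional[str],
    autogenerated: bool = False,
    subproject: str = "",
) -> object:
    # Pandas input: the real builder presumably hands off to the new local
    # trainer in pandas_trainer.py (not shown in this diff).
    if isinstance(dataset, pd.DataFrame):
        raise NotImplementedError("local training path lives in pandas_trainer.py")
    # Snowpark input: push fit() into a stored procedure via SnowparkModelTrainer.
    return SnowparkModelTrainer(
        estimator=estimator,
        dataset=dataset,
        session=dataset._session,
        input_cols=input_cols,
        label_cols=label_cols,
        sample_weight_col=sample_weight_col,
        autogenerated=autogenerated,
        subproject=subproject,
    )
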
@@ -496,11 +501,6 @@ class CalibratedClassifierCV(BaseTransformer):
          subproject=_SUBPROJECT,
          custom_tags=dict([("autogen", True)]),
      )
-     @telemetry.add_stmt_params_to_df(
-         project=_PROJECT,
-         subproject=_SUBPROJECT,
-         custom_tags=dict([("autogen", True)]),
-     )
      def predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, pd.DataFrame]:
          """Predict the target of new samples
          For more details on this function, see [sklearn.calibration.CalibratedClassifierCV.predict]
@@ -554,11 +554,6 @@ class CalibratedClassifierCV(BaseTransformer):
          subproject=_SUBPROJECT,
          custom_tags=dict([("autogen", True)]),
      )
-     @telemetry.add_stmt_params_to_df(
-         project=_PROJECT,
-         subproject=_SUBPROJECT,
-         custom_tags=dict([("autogen", True)]),
-     )
      def transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, pd.DataFrame]:
          """Method not supported for this class.

@@ -615,7 +610,8 @@ class CalibratedClassifierCV(BaseTransformer):
          if False:
              self.fit(dataset)
              assert self._sklearn_object is not None
-             return self._sklearn_object.labels_
+             labels : npt.NDArray[Any] = self._sklearn_object.labels_
+             return labels
          else:
              raise NotImplementedError

@@ -651,6 +647,7 @@ class CalibratedClassifierCV(BaseTransformer):
              output_cols = []

          # Make sure column names are valid snowflake identifiers.
+         assert output_cols is not None # Make MyPy happy
          rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]

          return rv
@@ -661,11 +658,6 @@ class CalibratedClassifierCV(BaseTransformer):
          subproject=_SUBPROJECT,
          custom_tags=dict([("autogen", True)]),
      )
-     @telemetry.add_stmt_params_to_df(
-         project=_PROJECT,
-         subproject=_SUBPROJECT,
-         custom_tags=dict([("autogen", True)]),
-     )
      def predict_proba(
          self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "predict_proba_"
      ) -> Union[DataFrame, pd.DataFrame]:
@@ -708,11 +700,6 @@ class CalibratedClassifierCV(BaseTransformer):
          subproject=_SUBPROJECT,
          custom_tags=dict([("autogen", True)]),
      )
-     @telemetry.add_stmt_params_to_df(
-         project=_PROJECT,
-         subproject=_SUBPROJECT,
-         custom_tags=dict([("autogen", True)]),
-     )
      def predict_log_proba(
          self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "predict_log_proba_"
      ) -> Union[DataFrame, pd.DataFrame]:
@@ -751,16 +738,6 @@ class CalibratedClassifierCV(BaseTransformer):
          return output_df

      @available_if(original_estimator_has_callable("decision_function")) # type: ignore[misc]
-     @telemetry.send_api_usage_telemetry(
-         project=_PROJECT,
-         subproject=_SUBPROJECT,
-         custom_tags=dict([("autogen", True)]),
-     )
-     @telemetry.add_stmt_params_to_df(
-         project=_PROJECT,
-         subproject=_SUBPROJECT,
-         custom_tags=dict([("autogen", True)]),
-     )
      def decision_function(
          self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "decision_function_"
      ) -> Union[DataFrame, pd.DataFrame]:
@@ -861,11 +838,6 @@ class CalibratedClassifierCV(BaseTransformer):
          subproject=_SUBPROJECT,
          custom_tags=dict([("autogen", True)]),
      )
-     @telemetry.add_stmt_params_to_df(
-         project=_PROJECT,
-         subproject=_SUBPROJECT,
-         custom_tags=dict([("autogen", True)]),
-     )
      def kneighbors(
          self,
          dataset: Union[DataFrame, pd.DataFrame],
@@ -925,9 +897,9 @@ class CalibratedClassifierCV(BaseTransformer):
          # For classifier, the type of predict is the same as the type of label
          if self._sklearn_object._estimator_type == 'classifier':
              # label columns is the desired type for output
-             outputs = _infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True)
+             outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
              # rename the output columns
-             outputs = model_signature_utils.rename_features(outputs, self.output_cols)
+             outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
              self._model_signature_dict["predict"] = ModelSignature(inputs,
                                                                     ([] if self._drop_input_cols else inputs)
                                                                     + outputs)
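
From the user's side the estimator API is unchanged; fitting on a Snowpark DataFrame now simply routes through the trainer path shown above. A minimal usage sketch, reusing the `session` object from the earlier sketch and placeholder table and column names:

from snowflake.ml.modeling.calibration import CalibratedClassifierCV

clf = CalibratedClassifierCV(
    input_cols=["FEATURE_1", "FEATURE_2"],   # illustrative column names
    label_cols=["LABEL"],
    output_cols=["PREDICTED_LABEL"],
)
clf.fit(session.table("TRAINING_DATA"))                    # training runs inside a Snowflake sproc
predictions = clf.predict(session.table("SCORING_DATA"))   # returns a Snowpark DataFrame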