snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (234)
  1. snowflake/ml/_internal/env_utils.py +77 -32
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/_internal/utils/identifier.py +3 -1
  8. snowflake/ml/_internal/utils/sql_identifier.py +2 -6
  9. snowflake/ml/dataset/__init__.py +10 -0
  10. snowflake/ml/dataset/dataset.py +454 -129
  11. snowflake/ml/dataset/dataset_factory.py +53 -0
  12. snowflake/ml/dataset/dataset_metadata.py +103 -0
  13. snowflake/ml/dataset/dataset_reader.py +202 -0
  14. snowflake/ml/feature_store/feature_store.py +531 -332
  15. snowflake/ml/feature_store/feature_view.py +40 -23
  16. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  17. snowflake/ml/fileset/sfcfs.py +56 -54
  18. snowflake/ml/fileset/snowfs.py +159 -0
  19. snowflake/ml/fileset/stage_fs.py +49 -17
  20. snowflake/ml/model/__init__.py +2 -2
  21. snowflake/ml/model/_api.py +16 -1
  22. snowflake/ml/model/_client/model/model_impl.py +27 -0
  23. snowflake/ml/model/_client/model/model_version_impl.py +137 -50
  24. snowflake/ml/model/_client/ops/model_ops.py +159 -40
  25. snowflake/ml/model/_client/sql/model.py +25 -2
  26. snowflake/ml/model/_client/sql/model_version.py +131 -2
  27. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  28. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  29. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  30. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  31. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  32. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  33. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
  34. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
  35. snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
  36. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  37. snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
  38. snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
  39. snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
  40. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  41. snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
  42. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
  43. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
  44. snowflake/ml/model/_packager/model_packager.py +2 -5
  45. snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
  46. snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
  47. snowflake/ml/model/type_hints.py +21 -2
  48. snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
  49. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
  50. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  51. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  52. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  53. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  54. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
  55. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
  56. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  57. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
  58. snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
  59. snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
  60. snowflake/ml/modeling/cluster/birch.py +248 -175
  61. snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
  62. snowflake/ml/modeling/cluster/dbscan.py +246 -175
  63. snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
  64. snowflake/ml/modeling/cluster/k_means.py +248 -175
  65. snowflake/ml/modeling/cluster/mean_shift.py +246 -175
  66. snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
  67. snowflake/ml/modeling/cluster/optics.py +246 -175
  68. snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
  69. snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
  70. snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
  71. snowflake/ml/modeling/compose/column_transformer.py +248 -175
  72. snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
  73. snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
  74. snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
  75. snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
  76. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
  77. snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
  78. snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
  79. snowflake/ml/modeling/covariance/oas.py +246 -175
  80. snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
  81. snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
  82. snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
  83. snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
  84. snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
  85. snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
  86. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
  87. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
  88. snowflake/ml/modeling/decomposition/pca.py +248 -175
  89. snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
  90. snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
  91. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
  92. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
  93. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
  94. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
  95. snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
  96. snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
  97. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
  98. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
  99. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
  100. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
  101. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
  102. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
  103. snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
  104. snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
  105. snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
  106. snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
  107. snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
  108. snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
  109. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
  110. snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
  111. snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
  112. snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
  113. snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
  114. snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
  115. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
  116. snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
  117. snowflake/ml/modeling/framework/_utils.py +8 -1
  118. snowflake/ml/modeling/framework/base.py +72 -37
  119. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
  120. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
  121. snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
  122. snowflake/ml/modeling/impute/knn_imputer.py +248 -175
  123. snowflake/ml/modeling/impute/missing_indicator.py +248 -175
  124. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
  125. snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
  126. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
  127. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
  128. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
  129. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
  130. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
  131. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
  132. snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
  133. snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
  134. snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
  135. snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
  136. snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
  137. snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
  138. snowflake/ml/modeling/linear_model/lars.py +246 -175
  139. snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
  140. snowflake/ml/modeling/linear_model/lasso.py +246 -175
  141. snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
  142. snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
  143. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
  144. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
  145. snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
  146. snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
  147. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
  148. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
  149. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
  150. snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
  151. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
  152. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
  153. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
  154. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
  155. snowflake/ml/modeling/linear_model/perceptron.py +246 -175
  156. snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
  157. snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
  158. snowflake/ml/modeling/linear_model/ridge.py +246 -175
  159. snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
  160. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
  161. snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
  162. snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
  163. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
  164. snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
  165. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
  166. snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
  167. snowflake/ml/modeling/manifold/isomap.py +248 -175
  168. snowflake/ml/modeling/manifold/mds.py +248 -175
  169. snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
  170. snowflake/ml/modeling/manifold/tsne.py +248 -175
  171. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
  172. snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
  173. snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
  174. snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
  175. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
  176. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
  177. snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
  178. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
  179. snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
  180. snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
  181. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
  182. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
  183. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
  184. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
  185. snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
  186. snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
  187. snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
  188. snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
  189. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
  190. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
  191. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
  192. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
  193. snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
  194. snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
  195. snowflake/ml/modeling/pipeline/pipeline.py +517 -35
  196. snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
  197. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
  198. snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
  199. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
  200. snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
  201. snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
  202. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
  203. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
  204. snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
  205. snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
  206. snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
  207. snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
  208. snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
  209. snowflake/ml/modeling/svm/linear_svc.py +246 -175
  210. snowflake/ml/modeling/svm/linear_svr.py +246 -175
  211. snowflake/ml/modeling/svm/nu_svc.py +246 -175
  212. snowflake/ml/modeling/svm/nu_svr.py +246 -175
  213. snowflake/ml/modeling/svm/svc.py +246 -175
  214. snowflake/ml/modeling/svm/svr.py +246 -175
  215. snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
  216. snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
  217. snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
  218. snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
  219. snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
  220. snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
  221. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
  222. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
  223. snowflake/ml/registry/model_registry.py +3 -149
  224. snowflake/ml/registry/registry.py +1 -1
  225. snowflake/ml/version.py +1 -1
  226. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
  227. snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
  228. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
  229. snowflake/ml/registry/_artifact_manager.py +0 -156
  230. snowflake/ml/registry/artifact.py +0 -46
  231. snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
  232. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  233. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  234. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -34,6 +34,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
34
34
  BatchInferenceKwargsTypedDict,
35
35
  ScoreKwargsTypedDict
36
36
  )
37
+ from snowflake.ml.model._signatures import utils as model_signature_utils
38
+ from snowflake.ml.model.model_signature import (
39
+ BaseFeatureSpec,
40
+ DataType,
41
+ FeatureSpec,
42
+ ModelSignature,
43
+ _infer_signature,
44
+ _rename_signature_with_snowflake_identifiers,
45
+ )
37
46
 
38
47
  from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
39
48
 
@@ -44,16 +53,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
44
53
  validate_sklearn_args,
45
54
  )
46
55
 
47
- from snowflake.ml.model.model_signature import (
48
- DataType,
49
- FeatureSpec,
50
- ModelSignature,
51
- _infer_signature,
52
- _rename_signature_with_snowflake_identifiers,
53
- BaseFeatureSpec,
54
- )
55
- from snowflake.ml.model._signatures import utils as model_signature_utils
56
-
57
56
  _PROJECT = "ModelDevelopment"
58
57
  # Derive subproject from module name by removing "sklearn"
59
58
  # and converting module name from underscore to CamelCase
@@ -62,12 +61,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.feature_selection".repla
62
61
 
63
62
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
64
63
 
65
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
66
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
67
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
68
- return check
69
-
70
-
71
64
  class SelectKBest(BaseTransformer):
72
65
  r"""Select features according to the k highest scores
73
66
  For more details on this class, see [sklearn.feature_selection.SelectKBest]
@@ -206,12 +199,7 @@ class SelectKBest(BaseTransformer):
206
199
  )
207
200
  return selected_cols
208
201
 
209
- @telemetry.send_api_usage_telemetry(
210
- project=_PROJECT,
211
- subproject=_SUBPROJECT,
212
- custom_tags=dict([("autogen", True)]),
213
- )
214
- def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "SelectKBest":
202
+ def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "SelectKBest":
215
203
  """Run score function on (X, y) and get the appropriate features
216
204
  For more details on this function, see [sklearn.feature_selection.SelectKBest.fit]
217
205
  (https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.SelectKBest.html#sklearn.feature_selection.SelectKBest.fit)
@@ -238,12 +226,14 @@ class SelectKBest(BaseTransformer):
238
226
 
239
227
  self._snowpark_cols = dataset.select(self.input_cols).columns
240
228
 
241
- # If we are already in a stored procedure, no need to kick off another one.
229
+ # If we are already in a stored procedure, no need to kick off another one.
242
230
  if SNOWML_SPROC_ENV in os.environ:
243
231
  statement_params = telemetry.get_function_usage_statement_params(
244
232
  project=_PROJECT,
245
233
  subproject=_SUBPROJECT,
246
- function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), SelectKBest.__class__.__name__),
234
+ function_name=telemetry.get_statement_params_full_func_name(
235
+ inspect.currentframe(), SelectKBest.__class__.__name__
236
+ ),
247
237
  api_calls=[Session.call],
248
238
  custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
249
239
  )
@@ -264,27 +254,24 @@ class SelectKBest(BaseTransformer):
264
254
  )
265
255
  self._sklearn_object = model_trainer.train()
266
256
  self._is_fitted = True
267
- self._get_model_signatures(dataset)
257
+ self._generate_model_signatures(dataset)
268
258
  return self
269
259
 
270
260
  def _batch_inference_validate_snowpark(
271
261
  self,
272
262
  dataset: DataFrame,
273
263
  inference_method: str,
274
- ) -> List[str]:
275
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
276
- return the available package that exists in the snowflake anaconda channel
264
+ ) -> None:
265
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
277
266
 
278
267
  Args:
279
268
  dataset: snowpark dataframe
280
269
  inference_method: the inference method such as predict, score...
281
-
270
+
282
271
  Raises:
283
272
  SnowflakeMLException: If the estimator is not fitted, raise error
284
273
  SnowflakeMLException: If the session is None, raise error
285
274
 
286
- Returns:
287
- A list of available package that exists in the snowflake anaconda channel
288
275
  """
289
276
  if not self._is_fitted:
290
277
  raise exceptions.SnowflakeMLException(
@@ -302,9 +289,7 @@ class SelectKBest(BaseTransformer):
302
289
  "Session must not specified for snowpark dataset."
303
290
  ),
304
291
  )
305
- # Validate that key package version in user workspace are supported in snowflake conda channel
306
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
307
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
292
+
308
293
 
309
294
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
310
295
  @telemetry.send_api_usage_telemetry(
@@ -338,7 +323,9 @@ class SelectKBest(BaseTransformer):
338
323
  # when it is classifier, infer the datatype from label columns
339
324
  if expected_type_inferred == "" and 'predict' in self.model_signatures:
340
325
  # Batch inference takes a single expected output column type. Use the first columns type for now.
341
- label_cols_signatures = [row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols]
326
+ label_cols_signatures = [
327
+ row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
328
+ ]
342
329
  if len(label_cols_signatures) == 0:
343
330
  error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
344
331
  raise exceptions.SnowflakeMLException(
@@ -346,25 +333,23 @@ class SelectKBest(BaseTransformer):
346
333
  original_exception=ValueError(error_str),
347
334
  )
348
335
 
349
- expected_type_inferred = convert_sp_to_sf_type(
350
- label_cols_signatures[0].as_snowpark_type()
351
- )
336
+ expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
352
337
 
353
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
354
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
338
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
339
+ self._deps = self._get_dependencies()
340
+ assert isinstance(
341
+ dataset._session, Session
342
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
355
343
 
356
344
  transform_kwargs = dict(
357
- session = dataset._session,
358
- dependencies = self._deps,
359
- drop_input_cols = self._drop_input_cols,
360
- expected_output_cols_type = expected_type_inferred,
345
+ session=dataset._session,
346
+ dependencies=self._deps,
347
+ drop_input_cols=self._drop_input_cols,
348
+ expected_output_cols_type=expected_type_inferred,
361
349
  )
362
350
 
363
351
  elif isinstance(dataset, pd.DataFrame):
364
- transform_kwargs = dict(
365
- snowpark_input_cols = self._snowpark_cols,
366
- drop_input_cols = self._drop_input_cols
367
- )
352
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
368
353
 
369
354
  transform_handlers = ModelTransformerBuilder.build(
370
355
  dataset=dataset,
@@ -406,7 +391,7 @@ class SelectKBest(BaseTransformer):
406
391
  Transformed dataset.
407
392
  """
408
393
  super()._check_dataset_type(dataset)
409
- inference_method="transform"
394
+ inference_method = "transform"
410
395
 
411
396
  # This dictionary contains optional kwargs for batch inference. These kwargs
412
397
  # are specific to the type of dataset used.
@@ -436,24 +421,19 @@ class SelectKBest(BaseTransformer):
436
421
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
437
422
  expected_dtype = convert_sp_to_sf_type(output_types[0])
438
423
 
439
- self._deps = self._batch_inference_validate_snowpark(
440
- dataset=dataset,
441
- inference_method=inference_method,
442
- )
424
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
425
+ self._deps = self._get_dependencies()
443
426
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
444
427
 
445
428
  transform_kwargs = dict(
446
- session = dataset._session,
447
- dependencies = self._deps,
448
- drop_input_cols = self._drop_input_cols,
449
- expected_output_cols_type = expected_dtype,
429
+ session=dataset._session,
430
+ dependencies=self._deps,
431
+ drop_input_cols=self._drop_input_cols,
432
+ expected_output_cols_type=expected_dtype,
450
433
  )
451
434
 
452
435
  elif isinstance(dataset, pd.DataFrame):
453
- transform_kwargs = dict(
454
- snowpark_input_cols = self._snowpark_cols,
455
- drop_input_cols = self._drop_input_cols
456
- )
436
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
457
437
 
458
438
  transform_handlers = ModelTransformerBuilder.build(
459
439
  dataset=dataset,
@@ -472,7 +452,11 @@ class SelectKBest(BaseTransformer):
472
452
  return output_df
473
453
 
474
454
  @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
475
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_predict_",) -> Union[DataFrame, pd.DataFrame]:
455
+ def fit_predict(
456
+ self,
457
+ dataset: Union[DataFrame, pd.DataFrame],
458
+ output_cols_prefix: str = "fit_predict_",
459
+ ) -> Union[DataFrame, pd.DataFrame]:
476
460
  """ Method not supported for this class.
477
461
 
478
462
 
@@ -497,22 +481,106 @@ class SelectKBest(BaseTransformer):
497
481
  )
498
482
  output_result, fitted_estimator = model_trainer.train_fit_predict(
499
483
  drop_input_cols=self._drop_input_cols,
500
- expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
484
+ expected_output_cols_list=(
485
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
486
+ ),
501
487
  )
502
488
  self._sklearn_object = fitted_estimator
503
489
  self._is_fitted = True
504
490
  return output_result
505
491
 
492
+
493
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
494
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
495
+ """ Fit to data, then transform it
496
+ For more details on this function, see [sklearn.feature_selection.SelectKBest.fit_transform]
497
+ (https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.SelectKBest.html#sklearn.feature_selection.SelectKBest.fit_transform)
498
+
499
+
500
+ Raises:
501
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
506
502
 
507
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
508
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
509
- """
503
+ Args:
504
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
505
+ Snowpark or Pandas DataFrame.
506
+ output_cols_prefix: Prefix for the response columns
510
507
  Returns:
511
508
  Transformed dataset.
512
509
  """
513
- self.fit(dataset)
514
- assert self._sklearn_object is not None
515
- return self._sklearn_object.embedding_
510
+ self._infer_input_output_cols(dataset)
511
+ super()._check_dataset_type(dataset)
512
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
513
+ estimator=self._sklearn_object,
514
+ dataset=dataset,
515
+ input_cols=self.input_cols,
516
+ label_cols=self.label_cols,
517
+ sample_weight_col=self.sample_weight_col,
518
+ autogenerated=self._autogenerated,
519
+ subproject=_SUBPROJECT,
520
+ )
521
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
522
+ drop_input_cols=self._drop_input_cols,
523
+ expected_output_cols_list=self.output_cols,
524
+ )
525
+ self._sklearn_object = fitted_estimator
526
+ self._is_fitted = True
527
+ return output_result
528
+
529
+
530
+ def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
531
+ """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
532
+ Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
533
+ """
534
+ output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
535
+ # The following condition is introduced for kneighbors methods, and not used in other methods
536
+ if output_cols:
537
+ output_cols = [
538
+ identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
539
+ for c in output_cols
540
+ ]
541
+ elif getattr(self._sklearn_object, "classes_", None) is None:
542
+ output_cols = [output_cols_prefix]
543
+ elif self._sklearn_object is not None:
544
+ classes = self._sklearn_object.classes_
545
+ if isinstance(classes, numpy.ndarray):
546
+ output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
547
+ elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
548
+ # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
549
+ output_cols = []
550
+ for i, cl in enumerate(classes):
551
+ # For binary classification, there is only one output column for each class
552
+ # ndarray as the two classes are complementary.
553
+ if len(cl) == 2:
554
+ output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
555
+ else:
556
+ output_cols.extend([
557
+ f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
558
+ ])
559
+ else:
560
+ output_cols = []
561
+
562
+ # Make sure column names are valid snowflake identifiers.
563
+ assert output_cols is not None # Make MyPy happy
564
+ rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
565
+
566
+ return rv
567
+
568
+ def _align_expected_output_names(
569
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
570
+ ) -> List[str]:
571
+ # in case the inferred output column names dimension is different
572
+ # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
573
+ output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
574
+ output_df_columns = list(output_df_pd.columns)
575
+ output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
576
+ if self.sample_weight_col:
577
+ output_df_columns_set -= set(self.sample_weight_col)
578
+ # if the dimension of inferred output column names is correct; use it
579
+ if len(expected_output_cols_list) == len(output_df_columns_set):
580
+ return expected_output_cols_list
581
+ # otherwise, use the sklearn estimator's output
582
+ else:
583
+ return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
516
584
 
517
585
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
518
586
  @telemetry.send_api_usage_telemetry(
@@ -544,24 +612,26 @@ class SelectKBest(BaseTransformer):
544
612
  # are specific to the type of dataset used.
545
613
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
546
614
 
615
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
616
+
547
617
  if isinstance(dataset, DataFrame):
548
- self._deps = self._batch_inference_validate_snowpark(
549
- dataset=dataset,
550
- inference_method=inference_method,
551
- )
552
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
618
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
619
+ self._deps = self._get_dependencies()
620
+ assert isinstance(
621
+ dataset._session, Session
622
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
553
623
  transform_kwargs = dict(
554
624
  session=dataset._session,
555
625
  dependencies=self._deps,
556
- drop_input_cols = self._drop_input_cols,
626
+ drop_input_cols=self._drop_input_cols,
557
627
  expected_output_cols_type="float",
558
628
  )
629
+ expected_output_cols = self._align_expected_output_names(
630
+ inference_method, dataset, expected_output_cols, output_cols_prefix
631
+ )
559
632
 
560
633
  elif isinstance(dataset, pd.DataFrame):
561
- transform_kwargs = dict(
562
- snowpark_input_cols = self._snowpark_cols,
563
- drop_input_cols = self._drop_input_cols
564
- )
634
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
565
635
 
566
636
  transform_handlers = ModelTransformerBuilder.build(
567
637
  dataset=dataset,
@@ -573,7 +643,7 @@ class SelectKBest(BaseTransformer):
573
643
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
574
644
  inference_method=inference_method,
575
645
  input_cols=self.input_cols,
576
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
646
+ expected_output_cols=expected_output_cols,
577
647
  **transform_kwargs
578
648
  )
579
649
  return output_df
@@ -603,29 +673,30 @@ class SelectKBest(BaseTransformer):
603
673
  Output dataset with log probability of the sample for each class in the model.
604
674
  """
605
675
  super()._check_dataset_type(dataset)
606
- inference_method="predict_log_proba"
676
+ inference_method = "predict_log_proba"
677
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
607
678
 
608
679
  # This dictionary contains optional kwargs for batch inference. These kwargs
609
680
  # are specific to the type of dataset used.
610
681
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
611
682
 
612
683
  if isinstance(dataset, DataFrame):
613
- self._deps = self._batch_inference_validate_snowpark(
614
- dataset=dataset,
615
- inference_method=inference_method,
616
- )
617
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
684
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
685
+ self._deps = self._get_dependencies()
686
+ assert isinstance(
687
+ dataset._session, Session
688
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
618
689
  transform_kwargs = dict(
619
690
  session=dataset._session,
620
691
  dependencies=self._deps,
621
- drop_input_cols = self._drop_input_cols,
692
+ drop_input_cols=self._drop_input_cols,
622
693
  expected_output_cols_type="float",
623
694
  )
695
+ expected_output_cols = self._align_expected_output_names(
696
+ inference_method, dataset, expected_output_cols, output_cols_prefix
697
+ )
624
698
  elif isinstance(dataset, pd.DataFrame):
625
- transform_kwargs = dict(
626
- snowpark_input_cols = self._snowpark_cols,
627
- drop_input_cols = self._drop_input_cols
628
- )
699
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
629
700
 
630
701
  transform_handlers = ModelTransformerBuilder.build(
631
702
  dataset=dataset,
@@ -638,7 +709,7 @@ class SelectKBest(BaseTransformer):
638
709
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
639
710
  inference_method=inference_method,
640
711
  input_cols=self.input_cols,
641
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
712
+ expected_output_cols=expected_output_cols,
642
713
  **transform_kwargs
643
714
  )
644
715
  return output_df
@@ -664,30 +735,32 @@ class SelectKBest(BaseTransformer):
664
735
  Output dataset with results of the decision function for the samples in input dataset.
665
736
  """
666
737
  super()._check_dataset_type(dataset)
667
- inference_method="decision_function"
738
+ inference_method = "decision_function"
668
739
 
669
740
  # This dictionary contains optional kwargs for batch inference. These kwargs
670
741
  # are specific to the type of dataset used.
671
742
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
672
743
 
744
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
745
+
673
746
  if isinstance(dataset, DataFrame):
674
- self._deps = self._batch_inference_validate_snowpark(
675
- dataset=dataset,
676
- inference_method=inference_method,
677
- )
678
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
747
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
748
+ self._deps = self._get_dependencies()
749
+ assert isinstance(
750
+ dataset._session, Session
751
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
679
752
  transform_kwargs = dict(
680
753
  session=dataset._session,
681
754
  dependencies=self._deps,
682
- drop_input_cols = self._drop_input_cols,
755
+ drop_input_cols=self._drop_input_cols,
683
756
  expected_output_cols_type="float",
684
757
  )
758
+ expected_output_cols = self._align_expected_output_names(
759
+ inference_method, dataset, expected_output_cols, output_cols_prefix
760
+ )
685
761
 
686
762
  elif isinstance(dataset, pd.DataFrame):
687
- transform_kwargs = dict(
688
- snowpark_input_cols = self._snowpark_cols,
689
- drop_input_cols = self._drop_input_cols
690
- )
763
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
691
764
 
692
765
  transform_handlers = ModelTransformerBuilder.build(
693
766
  dataset=dataset,
@@ -700,7 +773,7 @@ class SelectKBest(BaseTransformer):
700
773
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
701
774
  inference_method=inference_method,
702
775
  input_cols=self.input_cols,
703
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
776
+ expected_output_cols=expected_output_cols,
704
777
  **transform_kwargs
705
778
  )
706
779
  return output_df
@@ -729,17 +802,17 @@ class SelectKBest(BaseTransformer):
729
802
  Output dataset with probability of the sample for each class in the model.
730
803
  """
731
804
  super()._check_dataset_type(dataset)
732
- inference_method="score_samples"
805
+ inference_method = "score_samples"
733
806
 
734
807
  # This dictionary contains optional kwargs for batch inference. These kwargs
735
808
  # are specific to the type of dataset used.
736
809
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
737
810
 
811
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
812
+
738
813
  if isinstance(dataset, DataFrame):
739
- self._deps = self._batch_inference_validate_snowpark(
740
- dataset=dataset,
741
- inference_method=inference_method,
742
- )
814
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
815
+ self._deps = self._get_dependencies()
743
816
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
744
817
  transform_kwargs = dict(
745
818
  session=dataset._session,
@@ -747,6 +820,9 @@ class SelectKBest(BaseTransformer):
747
820
  drop_input_cols = self._drop_input_cols,
748
821
  expected_output_cols_type="float",
749
822
  )
823
+ expected_output_cols = self._align_expected_output_names(
824
+ inference_method, dataset, expected_output_cols, output_cols_prefix
825
+ )
750
826
 
751
827
  elif isinstance(dataset, pd.DataFrame):
752
828
  transform_kwargs = dict(
@@ -765,7 +841,7 @@ class SelectKBest(BaseTransformer):
765
841
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
766
842
  inference_method=inference_method,
767
843
  input_cols=self.input_cols,
768
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
844
+ expected_output_cols=expected_output_cols,
769
845
  **transform_kwargs
770
846
  )
771
847
  return output_df
@@ -798,17 +874,15 @@ class SelectKBest(BaseTransformer):
798
874
  transform_kwargs: ScoreKwargsTypedDict = dict()
799
875
 
800
876
  if isinstance(dataset, DataFrame):
801
- self._deps = self._batch_inference_validate_snowpark(
802
- dataset=dataset,
803
- inference_method="score",
804
- )
877
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
878
+ self._deps = self._get_dependencies()
805
879
  selected_cols = self._get_active_columns()
806
880
  if len(selected_cols) > 0:
807
881
  dataset = dataset.select(selected_cols)
808
882
  assert isinstance(dataset._session, Session) # keep mypy happy
809
883
  transform_kwargs = dict(
810
884
  session=dataset._session,
811
- dependencies=["snowflake-snowpark-python"] + self._deps,
885
+ dependencies=self._deps,
812
886
  score_sproc_imports=['sklearn'],
813
887
  )
814
888
  elif isinstance(dataset, pd.DataFrame):
@@ -873,11 +947,8 @@ class SelectKBest(BaseTransformer):
873
947
 
874
948
  if isinstance(dataset, DataFrame):
875
949
 
876
- self._deps = self._batch_inference_validate_snowpark(
877
- dataset=dataset,
878
- inference_method=inference_method,
879
-
880
- )
950
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
951
+ self._deps = self._get_dependencies()
881
952
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
882
953
  transform_kwargs = dict(
883
954
  session = dataset._session,
@@ -910,50 +981,84 @@ class SelectKBest(BaseTransformer):
910
981
  )
911
982
  return output_df
912
983
 
984
+
985
+
986
+ def to_sklearn(self) -> Any:
987
+ """Get sklearn.feature_selection.SelectKBest object.
988
+ """
989
+ if self._sklearn_object is None:
990
+ self._sklearn_object = self._create_sklearn_object()
991
+ return self._sklearn_object
992
+
993
+ def to_xgboost(self) -> Any:
994
+ raise exceptions.SnowflakeMLException(
995
+ error_code=error_codes.METHOD_NOT_ALLOWED,
996
+ original_exception=AttributeError(
997
+ modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
998
+ "to_xgboost()",
999
+ "to_sklearn()"
1000
+ )
1001
+ ),
1002
+ )
913
1003
 
914
- def _get_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
1004
+ def to_lightgbm(self) -> Any:
1005
+ raise exceptions.SnowflakeMLException(
1006
+ error_code=error_codes.METHOD_NOT_ALLOWED,
1007
+ original_exception=AttributeError(
1008
+ modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1009
+ "to_lightgbm()",
1010
+ "to_sklearn()"
1011
+ )
1012
+ ),
1013
+ )
1014
+
1015
+ def _get_dependencies(self) -> List[str]:
1016
+ return self._deps
1017
+
1018
+
1019
+ def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
915
1020
  self._model_signature_dict = dict()
916
1021
 
917
1022
  PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
918
1023
 
919
- inputs = list(_infer_signature(dataset[self.input_cols], "input"))
1024
+ inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
920
1025
  outputs: List[BaseFeatureSpec] = []
921
1026
  if hasattr(self, "predict"):
922
1027
  # keep mypy happy
923
- assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
1028
+ assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
924
1029
  # For classifier, the type of predict is the same as the type of label
925
- if self._sklearn_object._estimator_type == 'classifier':
926
- # label columns is the desired type for output
1030
+ if self._sklearn_object._estimator_type == "classifier":
1031
+ # label columns is the desired type for output
927
1032
  outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
928
1033
  # rename the output columns
929
1034
  outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
930
- self._model_signature_dict["predict"] = ModelSignature(inputs,
931
- ([] if self._drop_input_cols else inputs)
932
- + outputs)
1035
+ self._model_signature_dict["predict"] = ModelSignature(
1036
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1037
+ )
933
1038
  # For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
934
1039
  # For outlier models, returns -1 for outliers and 1 for inliers.
935
- # Clusterer returns int64 cluster labels.
1040
+ # Clusterer returns int64 cluster labels.
936
1041
  elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
937
1042
  outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
938
- self._model_signature_dict["predict"] = ModelSignature(inputs,
939
- ([] if self._drop_input_cols else inputs)
940
- + outputs)
941
-
1043
+ self._model_signature_dict["predict"] = ModelSignature(
1044
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1045
+ )
1046
+
942
1047
  # For regressor, the type of predict is float64
943
- elif self._sklearn_object._estimator_type == 'regressor':
1048
+ elif self._sklearn_object._estimator_type == "regressor":
944
1049
  outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
945
- self._model_signature_dict["predict"] = ModelSignature(inputs,
946
- ([] if self._drop_input_cols else inputs)
947
- + outputs)
948
-
1050
+ self._model_signature_dict["predict"] = ModelSignature(
1051
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1052
+ )
1053
+
949
1054
  for prob_func in PROB_FUNCTIONS:
950
1055
  if hasattr(self, prob_func):
951
1056
  output_cols_prefix: str = f"{prob_func}_"
952
1057
  output_column_names = self._get_output_column_names(output_cols_prefix)
953
1058
  outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
954
- self._model_signature_dict[prob_func] = ModelSignature(inputs,
955
- ([] if self._drop_input_cols else inputs)
956
- + outputs)
1059
+ self._model_signature_dict[prob_func] = ModelSignature(
1060
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1061
+ )
957
1062
 
958
1063
  # Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
959
1064
  items = list(self._model_signature_dict.items())
@@ -966,10 +1071,10 @@ class SelectKBest(BaseTransformer):
966
1071
  """Returns model signature of current class.
967
1072
 
968
1073
  Raises:
969
- exceptions.SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
1074
+ SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
970
1075
 
971
1076
  Returns:
972
- Dict[str, ModelSignature]: each method and its input output signature
1077
+ Dict with each method and its input output signature
973
1078
  """
974
1079
  if self._model_signature_dict is None:
975
1080
  raise exceptions.SnowflakeMLException(
@@ -977,35 +1082,3 @@ class SelectKBest(BaseTransformer):
977
1082
  original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
978
1083
  )
979
1084
  return self._model_signature_dict
980
-
981
- def to_sklearn(self) -> Any:
982
- """Get sklearn.feature_selection.SelectKBest object.
983
- """
984
- if self._sklearn_object is None:
985
- self._sklearn_object = self._create_sklearn_object()
986
- return self._sklearn_object
987
-
988
- def to_xgboost(self) -> Any:
989
- raise exceptions.SnowflakeMLException(
990
- error_code=error_codes.METHOD_NOT_ALLOWED,
991
- original_exception=AttributeError(
992
- modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
993
- "to_xgboost()",
994
- "to_sklearn()"
995
- )
996
- ),
997
- )
998
-
999
- def to_lightgbm(self) -> Any:
1000
- raise exceptions.SnowflakeMLException(
1001
- error_code=error_codes.METHOD_NOT_ALLOWED,
1002
- original_exception=AttributeError(
1003
- modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1004
- "to_lightgbm()",
1005
- "to_sklearn()"
1006
- )
1007
- ),
1008
- )
1009
-
1010
- def _get_dependencies(self) -> List[str]:
1011
- return self._deps