snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
The hunks below are shown for the feature_selection estimators (items 98-102 above); the same autogenerated pattern is applied across the other modeling classes in this release.

snowflake/ml/modeling/feature_selection/generic_univariate_select.py
@@ -55,6 +55,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.feature_selection".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class GenericUnivariateSelect(BaseTransformer):
     r"""Univariate feature selector with configurable strategy
     For more details on this class, see [sklearn.feature_selection.GenericUnivariateSelect]
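The two check factories above feed scikit-learn's available_if decorator (applied to fit_predict/fit_transform later in this file). When the predicate returns a falsy value, the decorated method is hidden from the instance outright (hasattr returns False) instead of existing and raising NotImplementedError. A minimal, self-contained sketch of that gating pattern follows; the Wrapper class and _fit_predict_enabled helper are invented for illustration, and only available_if is real sklearn API:

# Sketch of the gating used above; every name except `available_if` is invented.
from sklearn.utils.metaestimators import available_if


def _fit_predict_enabled(est: "Wrapper") -> bool:
    # Mirrors the generated check: `False and ...` short-circuits to False,
    # so the method stays unconditionally disabled for these feature selectors.
    return False and callable(getattr(est._wrapped, "fit_predict", None))


class Wrapper:
    def __init__(self, wrapped: object) -> None:
        self._wrapped = wrapped

    @available_if(_fit_predict_enabled)
    def fit_predict(self, X):
        return self._wrapped.fit_predict(X)


# Attribute access raises AttributeError because the predicate is falsy.
print(hasattr(Wrapper(object()), "fit_predict"))  # False

Note that the generated code calls the factory at decoration time (@available_if(_is_fit_predict_method_enabled())), so each class binds its own predicate.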
@@ -139,7 +151,9 @@ class GenericUnivariateSelect(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -217,11 +231,6 @@ class GenericUnivariateSelect(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -249,7 +258,9 @@ class GenericUnivariateSelect(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
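The new use_external_memory_version and batch_size arguments are threaded through to the trainer builder in every generated estimator; for the feature selectors they are hard-wired off (False / -1), and only the XGBoost estimators (items 198-201 above, together with the new xgboost_external_memory_trainer.py, item 45) actually exercise them. This diff does not show the trainer internals, but the external-memory technique the module name points to is XGBoost's DataIter protocol, sketched below under that assumption; the batching scheme here is illustrative, not Snowflake's:

import numpy as np
import xgboost as xgb  # DataIter/DMatrix iterator support requires xgboost >= 1.5


class BatchIter(xgb.DataIter):
    # Feeds fixed-size batches to DMatrix so the full dataset is never
    # materialized in memory at once; XGBoost caches pages on disk.
    def __init__(self, batches):
        self._batches = batches
        self._it = 0
        super().__init__(cache_prefix="./dmatrix_cache")

    def next(self, input_data):
        if self._it == len(self._batches):
            return 0  # no more batches
        X, y = self._batches[self._it]
        input_data(data=X, label=y)
        self._it += 1
        return 1

    def reset(self):
        self._it = 0


batches = [(np.random.rand(256, 4), np.random.randint(2, size=256)) for _ in range(4)]
dtrain = xgb.DMatrix(BatchIter(batches))
booster = xgb.train({"objective": "binary:logistic"}, dtrain, num_boost_round=5)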
@@ -520,6 +531,22 @@ class GenericUnivariateSelect(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
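Restated in isolation, the fallback above picks a concrete Snowflake type only when every type inferred from the input columns agrees and the column counts line up; any mismatch degrades to ARRAY (variant). A hypothetical, dependency-free restatement (the function name and signature are invented, not package code):

from typing import List, Optional


def infer_expected_dtype(
    inferred_types: List[str],  # Snowflake type names inferred from the input columns
    n_output_cols: int,
    n_clusters: Optional[int] = None,
    n_components: Optional[int] = None,
) -> str:
    # Cluster or component counts that disagree with the output width imply
    # list-valued rows, hence ARRAY.
    if n_clusters is not None and n_clusters != n_output_cols:
        return "ARRAY"
    if n_components is not None and n_components != n_output_cols:
        return "ARRAY"
    # A concrete type is safe only when all inferred types agree and the
    # input width matches the output width.
    if inferred_types and all(t == inferred_types[0] for t in inferred_types) and len(inferred_types) == n_output_cols:
        return inferred_types[0]
    return "ARRAY"


print(infer_expected_dtype(["FLOAT", "FLOAT"], 2))   # FLOAT
print(infer_expected_dtype(["FLOAT", "BIGINT"], 2))  # ARRAY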
@@ -535,8 +562,8 @@ class GenericUnivariateSelect(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -549,13 +576,21 @@ class GenericUnivariateSelect(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/feature_selection/select_fdr.py
@@ -55,6 +55,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.feature_selection".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class SelectFdr(BaseTransformer):
     r"""Filter: Select the p-values for an estimated false discovery rate
     For more details on this class, see [sklearn.feature_selection.SelectFdr]
@@ -136,7 +148,9 @@ class SelectFdr(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -213,11 +227,6 @@ class SelectFdr(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -245,7 +254,9 @@ class SelectFdr(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -516,6 +527,22 @@ class SelectFdr(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -531,8 +558,8 @@ class SelectFdr(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -545,13 +572,21 @@ class SelectFdr(BaseTransformer):
         Returns:
            Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/feature_selection/select_fpr.py
@@ -55,6 +55,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.feature_selection".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class SelectFpr(BaseTransformer):
     r"""Filter: Select the pvalues below alpha based on a FPR test
     For more details on this class, see [sklearn.feature_selection.SelectFpr]
@@ -136,7 +148,9 @@ class SelectFpr(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -213,11 +227,6 @@ class SelectFpr(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -245,7 +254,9 @@ class SelectFpr(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -516,6 +527,22 @@ class SelectFpr(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -531,8 +558,8 @@ class SelectFpr(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -545,13 +572,21 @@ class SelectFpr(BaseTransformer):
         Returns:
            Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/feature_selection/select_fwe.py
@@ -55,6 +55,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.feature_selection".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class SelectFwe(BaseTransformer):
     r"""Filter: Select the p-values corresponding to Family-wise error rate
     For more details on this class, see [sklearn.feature_selection.SelectFwe]
@@ -136,7 +148,9 @@ class SelectFwe(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -213,11 +227,6 @@ class SelectFwe(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -245,7 +254,9 @@ class SelectFwe(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -516,6 +527,22 @@ class SelectFwe(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -531,8 +558,8 @@ class SelectFwe(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -545,13 +572,21 @@ class SelectFwe(BaseTransformer):
         Returns:
            Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/feature_selection/select_k_best.py
@@ -55,6 +55,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.feature_selection".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class SelectKBest(BaseTransformer):
     r"""Select features according to the k highest scores
     For more details on this class, see [sklearn.feature_selection.SelectKBest]
@@ -137,7 +149,9 @@ class SelectKBest(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -214,11 +228,6 @@ class SelectKBest(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -246,7 +255,9 @@ class SelectKBest(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -517,6 +528,22 @@ class SelectKBest(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -532,8 +559,8 @@ class SelectKBest(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -546,13 +573,21 @@ class SelectKBest(BaseTransformer):
         Returns:
            Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.