snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package as they appear in their public registry. It is provided for informational purposes only.
Files changed (215)
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
snowflake/ml/modeling/decomposition/sparse_pca.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class SparsePCA(BaseTransformer):
     r"""Sparse Principal Components Analysis (SparsePCA)
     For more details on this class, see [sklearn.decomposition.SparsePCA]
@@ -181,7 +193,9 @@ class SparsePCA(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -267,11 +281,6 @@ class SparsePCA(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -299,7 +308,9 @@ class SparsePCA(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -570,6 +581,22 @@ class SparsePCA(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -585,8 +612,8 @@ class SparsePCA(BaseTransformer):

         return output_df

-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -599,13 +626,21 @@ class SparsePCA(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/decomposition/truncated_svd.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class TruncatedSVD(BaseTransformer):
     r"""Dimensionality reduction using truncated SVD (aka LSA)
     For more details on this class, see [sklearn.decomposition.TruncatedSVD]
@@ -166,7 +178,9 @@ class TruncatedSVD(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -248,11 +262,6 @@ class TruncatedSVD(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -280,7 +289,9 @@ class TruncatedSVD(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -551,6 +562,22 @@ class TruncatedSVD(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -566,8 +593,8 @@ class TruncatedSVD(BaseTransformer):

         return output_df

-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -580,13 +607,21 @@ class TruncatedSVD(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.discriminant_analysis".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class LinearDiscriminantAnalysis(BaseTransformer):
     r"""Linear Discriminant Analysis
     For more details on this class, see [sklearn.discriminant_analysis.LinearDiscriminantAnalysis]
@@ -183,7 +195,9 @@ class LinearDiscriminantAnalysis(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -265,11 +279,6 @@ class LinearDiscriminantAnalysis(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -297,7 +306,9 @@ class LinearDiscriminantAnalysis(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -570,6 +581,22 @@ class LinearDiscriminantAnalysis(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -585,8 +612,8 @@ class LinearDiscriminantAnalysis(BaseTransformer):

         return output_df

-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -599,13 +626,21 @@ class LinearDiscriminantAnalysis(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.discriminant_analysis".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class QuadraticDiscriminantAnalysis(BaseTransformer):
     r"""Quadratic Discriminant Analysis
     For more details on this class, see [sklearn.discriminant_analysis.QuadraticDiscriminantAnalysis]
@@ -148,7 +160,9 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -227,11 +241,6 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -259,7 +268,9 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -530,6 +541,22 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -545,8 +572,8 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):

         return output_df

-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -559,13 +586,21 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/ensemble/ada_boost_classifier.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class AdaBoostClassifier(BaseTransformer):
     r"""An AdaBoost classifier
     For more details on this class, see [sklearn.ensemble.AdaBoostClassifier]
@@ -169,7 +181,9 @@ class AdaBoostClassifier(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
         deps = deps | gather_dependencies(estimator)
         deps = deps | gather_dependencies(base_estimator)
@@ -252,11 +266,6 @@ class AdaBoostClassifier(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -284,7 +293,9 @@ class AdaBoostClassifier(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -555,6 +566,22 @@ class AdaBoostClassifier(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -570,8 +597,8 @@ class AdaBoostClassifier(BaseTransformer):

         return output_df

-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -584,13 +611,21 @@ class AdaBoostClassifier(BaseTransformer):
        Returns:
            Predicted dataset.
        """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.