snowflake-ml-python 1.1.2-py3-none-any.whl → 1.2.1-py3-none-any.whl

This diff shows the contents of two publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
Files changed (215)
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
snowflake/ml/modeling/ensemble/ada_boost_regressor.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class AdaBoostRegressor(BaseTransformer):
     r"""An AdaBoost regressor
     For more details on this class, see [sklearn.ensemble.AdaBoostRegressor]
@@ -166,7 +178,9 @@ class AdaBoostRegressor(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
         deps = deps | gather_dependencies(estimator)
         deps = deps | gather_dependencies(base_estimator)
@@ -249,11 +263,6 @@ class AdaBoostRegressor(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -281,7 +290,9 @@ class AdaBoostRegressor(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -552,6 +563,22 @@ class AdaBoostRegressor(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -567,8 +594,8 @@ class AdaBoostRegressor(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -581,13 +608,21 @@ class AdaBoostRegressor(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/ensemble/bagging_classifier.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class BaggingClassifier(BaseTransformer):
     r"""A Bagging classifier
     For more details on this class, see [sklearn.ensemble.BaggingClassifier]
@@ -195,7 +207,9 @@ class BaggingClassifier(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
         deps = deps | gather_dependencies(estimator)
         deps = deps | gather_dependencies(base_estimator)
@@ -284,11 +298,6 @@ class BaggingClassifier(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -316,7 +325,9 @@ class BaggingClassifier(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -587,6 +598,22 @@ class BaggingClassifier(BaseTransformer):
             # each row containing a list of values.
            expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -602,8 +629,8 @@ class BaggingClassifier(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -616,13 +643,21 @@ class BaggingClassifier(BaseTransformer):
         Returns:
            Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/ensemble/bagging_regressor.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class BaggingRegressor(BaseTransformer):
     r"""A Bagging regressor
     For more details on this class, see [sklearn.ensemble.BaggingRegressor]
@@ -195,7 +207,9 @@ class BaggingRegressor(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
         deps = deps | gather_dependencies(estimator)
         deps = deps | gather_dependencies(base_estimator)
@@ -284,11 +298,6 @@ class BaggingRegressor(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -316,7 +325,9 @@ class BaggingRegressor(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -587,6 +598,22 @@ class BaggingRegressor(BaseTransformer):
             # each row containing a list of values.
            expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -602,8 +629,8 @@ class BaggingRegressor(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -616,13 +643,21 @@ class BaggingRegressor(BaseTransformer):
         Returns:
            Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/ensemble/extra_trees_classifier.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class ExtraTreesClassifier(BaseTransformer):
     r"""An extra-trees classifier
     For more details on this class, see [sklearn.ensemble.ExtraTreesClassifier]
@@ -294,7 +306,9 @@ class ExtraTreesClassifier(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -387,11 +401,6 @@ class ExtraTreesClassifier(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -419,7 +428,9 @@ class ExtraTreesClassifier(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -690,6 +701,22 @@ class ExtraTreesClassifier(BaseTransformer):
             # each row containing a list of values.
            expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -705,8 +732,8 @@ class ExtraTreesClassifier(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -719,13 +746,21 @@ class ExtraTreesClassifier(BaseTransformer):
         Returns:
            Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/ensemble/extra_trees_regressor.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class ExtraTreesRegressor(BaseTransformer):
     r"""An extra-trees regressor
     For more details on this class, see [sklearn.ensemble.ExtraTreesRegressor]
@@ -274,7 +286,9 @@ class ExtraTreesRegressor(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -366,11 +380,6 @@ class ExtraTreesRegressor(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -398,7 +407,9 @@ class ExtraTreesRegressor(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -669,6 +680,22 @@ class ExtraTreesRegressor(BaseTransformer):
             # each row containing a list of values.
            expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -684,8 +711,8 @@ class ExtraTreesRegressor(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -698,13 +725,21 @@ class ExtraTreesRegressor(BaseTransformer):
         Returns:
            Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.