snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215) hide show
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
54
54
  _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklearn.", "").split("_")])
55
55
 
56
56
 
57
+ def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
58
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
59
+ return False and callable(getattr(self._sklearn_object, "fit_predict", None))
60
+ return check
61
+
62
+
63
+ def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
+ return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
+ return check
67
+
68
+
57
69
  class RandomForestClassifier(BaseTransformer):
58
70
  r"""A random forest classifier
59
71
  For more details on this class, see [sklearn.ensemble.RandomForestClassifier]
@@ -290,7 +302,9 @@ class RandomForestClassifier(BaseTransformer):
290
302
  self.set_label_cols(label_cols)
291
303
  self.set_passthrough_cols(passthrough_cols)
292
304
  self.set_drop_input_cols(drop_input_cols)
293
- self.set_sample_weight_col(sample_weight_col)
305
+ self.set_sample_weight_col(sample_weight_col)
306
+ self._use_external_memory_version = False
307
+ self._batch_size = -1
294
308
  deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
295
309
 
296
310
  self._deps = list(deps)
@@ -383,11 +397,6 @@ class RandomForestClassifier(BaseTransformer):
383
397
  if isinstance(dataset, DataFrame):
384
398
  session = dataset._session
385
399
  assert session is not None # keep mypy happy
386
- # Validate that key package version in user workspace are supported in snowflake conda channel
387
- # If customer doesn't have package in conda channel, replace the ones have the closest versions
388
- self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
389
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
390
-
391
400
  # Specify input columns so column pruning will be enforced
392
401
  selected_cols = self._get_active_columns()
393
402
  if len(selected_cols) > 0:
@@ -415,7 +424,9 @@ class RandomForestClassifier(BaseTransformer):
415
424
  label_cols=self.label_cols,
416
425
  sample_weight_col=self.sample_weight_col,
417
426
  autogenerated=self._autogenerated,
418
- subproject=_SUBPROJECT
427
+ subproject=_SUBPROJECT,
428
+ use_external_memory_version=self._use_external_memory_version,
429
+ batch_size=self._batch_size,
419
430
  )
420
431
  self._sklearn_object = model_trainer.train()
421
432
  self._is_fitted = True
@@ -686,6 +697,22 @@ class RandomForestClassifier(BaseTransformer):
686
697
  # each row containing a list of values.
687
698
  expected_dtype = "ARRAY"
688
699
 
700
+ # If we were unable to assign a type to this transform in the factory, infer the type here.
701
+ if expected_dtype == "":
702
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
703
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
704
+ expected_dtype = "ARRAY"
705
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
706
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
707
+ expected_dtype = "ARRAY"
708
+ else:
709
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
710
+ # We can only infer the output types from the input types if the following two statements are true:
711
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
712
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
713
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
714
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
715
+
689
716
  output_df = self._batch_inference(
690
717
  dataset=dataset,
691
718
  inference_method="transform",
@@ -701,8 +728,8 @@ class RandomForestClassifier(BaseTransformer):
701
728
 
702
729
  return output_df
703
730
 
704
- @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
705
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
731
+ @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
732
+ def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
706
733
  """ Method not supported for this class.
707
734
 
708
735
 
@@ -715,13 +742,21 @@ class RandomForestClassifier(BaseTransformer):
715
742
  Returns:
716
743
  Predicted dataset.
717
744
  """
718
- if False:
719
- self.fit(dataset)
720
- assert self._sklearn_object is not None
721
- labels : npt.NDArray[Any] = self._sklearn_object.labels_
722
- return labels
723
- else:
724
- raise NotImplementedError
745
+ self.fit(dataset)
746
+ assert self._sklearn_object is not None
747
+ return self._sklearn_object.labels_
748
+
749
+
750
+ @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
751
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
752
+ """
753
+ Returns:
754
+ Transformed dataset.
755
+ """
756
+ self.fit(dataset)
757
+ assert self._sklearn_object is not None
758
+ return self._sklearn_object.embedding_
759
+
725
760
 
726
761
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
727
762
  """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
54
54
  _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklearn.", "").split("_")])
55
55
 
56
56
 
57
+ def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
58
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
59
+ return False and callable(getattr(self._sklearn_object, "fit_predict", None))
60
+ return check
61
+
62
+
63
+ def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
+ return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
+ return check
67
+
68
+
57
69
  class RandomForestRegressor(BaseTransformer):
58
70
  r"""A random forest regressor
59
71
  For more details on this class, see [sklearn.ensemble.RandomForestRegressor]
@@ -270,7 +282,9 @@ class RandomForestRegressor(BaseTransformer):
270
282
  self.set_label_cols(label_cols)
271
283
  self.set_passthrough_cols(passthrough_cols)
272
284
  self.set_drop_input_cols(drop_input_cols)
273
- self.set_sample_weight_col(sample_weight_col)
285
+ self.set_sample_weight_col(sample_weight_col)
286
+ self._use_external_memory_version = False
287
+ self._batch_size = -1
274
288
  deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
275
289
 
276
290
  self._deps = list(deps)
@@ -362,11 +376,6 @@ class RandomForestRegressor(BaseTransformer):
362
376
  if isinstance(dataset, DataFrame):
363
377
  session = dataset._session
364
378
  assert session is not None # keep mypy happy
365
- # Validate that key package version in user workspace are supported in snowflake conda channel
366
- # If customer doesn't have package in conda channel, replace the ones have the closest versions
367
- self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
368
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
369
-
370
379
  # Specify input columns so column pruning will be enforced
371
380
  selected_cols = self._get_active_columns()
372
381
  if len(selected_cols) > 0:
@@ -394,7 +403,9 @@ class RandomForestRegressor(BaseTransformer):
394
403
  label_cols=self.label_cols,
395
404
  sample_weight_col=self.sample_weight_col,
396
405
  autogenerated=self._autogenerated,
397
- subproject=_SUBPROJECT
406
+ subproject=_SUBPROJECT,
407
+ use_external_memory_version=self._use_external_memory_version,
408
+ batch_size=self._batch_size,
398
409
  )
399
410
  self._sklearn_object = model_trainer.train()
400
411
  self._is_fitted = True
@@ -665,6 +676,22 @@ class RandomForestRegressor(BaseTransformer):
665
676
  # each row containing a list of values.
666
677
  expected_dtype = "ARRAY"
667
678
 
679
+ # If we were unable to assign a type to this transform in the factory, infer the type here.
680
+ if expected_dtype == "":
681
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
682
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
683
+ expected_dtype = "ARRAY"
684
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
685
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
686
+ expected_dtype = "ARRAY"
687
+ else:
688
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
689
+ # We can only infer the output types from the input types if the following two statements are true:
690
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
691
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
692
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
693
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
694
+
668
695
  output_df = self._batch_inference(
669
696
  dataset=dataset,
670
697
  inference_method="transform",
@@ -680,8 +707,8 @@ class RandomForestRegressor(BaseTransformer):
680
707
 
681
708
  return output_df
682
709
 
683
- @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
684
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
710
+ @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
711
+ def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
685
712
  """ Method not supported for this class.
686
713
 
687
714
 
@@ -694,13 +721,21 @@ class RandomForestRegressor(BaseTransformer):
694
721
  Returns:
695
722
  Predicted dataset.
696
723
  """
697
- if False:
698
- self.fit(dataset)
699
- assert self._sklearn_object is not None
700
- labels : npt.NDArray[Any] = self._sklearn_object.labels_
701
- return labels
702
- else:
703
- raise NotImplementedError
724
+ self.fit(dataset)
725
+ assert self._sklearn_object is not None
726
+ return self._sklearn_object.labels_
727
+
728
+
729
+ @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
730
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
731
+ """
732
+ Returns:
733
+ Transformed dataset.
734
+ """
735
+ self.fit(dataset)
736
+ assert self._sklearn_object is not None
737
+ return self._sklearn_object.embedding_
738
+
704
739
 
705
740
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
706
741
  """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
54
54
  _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklearn.", "").split("_")])
55
55
 
56
56
 
57
+ def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
58
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
59
+ return False and callable(getattr(self._sklearn_object, "fit_predict", None))
60
+ return check
61
+
62
+
63
+ def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
+ return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
+ return check
67
+
68
+
57
69
  class StackingRegressor(BaseTransformer):
58
70
  r"""Stack of estimators with a final regressor
59
71
  For more details on this class, see [sklearn.ensemble.StackingRegressor]
@@ -180,7 +192,9 @@ class StackingRegressor(BaseTransformer):
180
192
  self.set_label_cols(label_cols)
181
193
  self.set_passthrough_cols(passthrough_cols)
182
194
  self.set_drop_input_cols(drop_input_cols)
183
- self.set_sample_weight_col(sample_weight_col)
195
+ self.set_sample_weight_col(sample_weight_col)
196
+ self._use_external_memory_version = False
197
+ self._batch_size = -1
184
198
  deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
185
199
  deps = deps | gather_dependencies(estimators)
186
200
  deps = deps | gather_dependencies(final_estimator)
@@ -263,11 +277,6 @@ class StackingRegressor(BaseTransformer):
263
277
  if isinstance(dataset, DataFrame):
264
278
  session = dataset._session
265
279
  assert session is not None # keep mypy happy
266
- # Validate that key package version in user workspace are supported in snowflake conda channel
267
- # If customer doesn't have package in conda channel, replace the ones have the closest versions
268
- self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
269
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
270
-
271
280
  # Specify input columns so column pruning will be enforced
272
281
  selected_cols = self._get_active_columns()
273
282
  if len(selected_cols) > 0:
@@ -295,7 +304,9 @@ class StackingRegressor(BaseTransformer):
295
304
  label_cols=self.label_cols,
296
305
  sample_weight_col=self.sample_weight_col,
297
306
  autogenerated=self._autogenerated,
298
- subproject=_SUBPROJECT
307
+ subproject=_SUBPROJECT,
308
+ use_external_memory_version=self._use_external_memory_version,
309
+ batch_size=self._batch_size,
299
310
  )
300
311
  self._sklearn_object = model_trainer.train()
301
312
  self._is_fitted = True
@@ -568,6 +579,22 @@ class StackingRegressor(BaseTransformer):
568
579
  # each row containing a list of values.
569
580
  expected_dtype = "ARRAY"
570
581
 
582
+ # If we were unable to assign a type to this transform in the factory, infer the type here.
583
+ if expected_dtype == "":
584
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
585
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
586
+ expected_dtype = "ARRAY"
587
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
588
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
589
+ expected_dtype = "ARRAY"
590
+ else:
591
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
592
+ # We can only infer the output types from the input types if the following two statements are true:
593
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
594
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
595
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
596
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
597
+
571
598
  output_df = self._batch_inference(
572
599
  dataset=dataset,
573
600
  inference_method="transform",
@@ -583,8 +610,8 @@ class StackingRegressor(BaseTransformer):
583
610
 
584
611
  return output_df
585
612
 
586
- @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
587
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
613
+ @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
614
+ def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
588
615
  """ Method not supported for this class.
589
616
 
590
617
 
@@ -597,13 +624,21 @@ class StackingRegressor(BaseTransformer):
597
624
  Returns:
598
625
  Predicted dataset.
599
626
  """
600
- if False:
601
- self.fit(dataset)
602
- assert self._sklearn_object is not None
603
- labels : npt.NDArray[Any] = self._sklearn_object.labels_
604
- return labels
605
- else:
606
- raise NotImplementedError
627
+ self.fit(dataset)
628
+ assert self._sklearn_object is not None
629
+ return self._sklearn_object.labels_
630
+
631
+
632
+ @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
633
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
634
+ """
635
+ Returns:
636
+ Transformed dataset.
637
+ """
638
+ self.fit(dataset)
639
+ assert self._sklearn_object is not None
640
+ return self._sklearn_object.embedding_
641
+
607
642
 
608
643
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
609
644
  """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
54
54
  _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklearn.", "").split("_")])
55
55
 
56
56
 
57
+ def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
58
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
59
+ return False and callable(getattr(self._sklearn_object, "fit_predict", None))
60
+ return check
61
+
62
+
63
+ def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
+ return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
+ return check
67
+
68
+
57
69
  class VotingClassifier(BaseTransformer):
58
70
  r"""Soft Voting/Majority Rule classifier for unfitted estimators
59
71
  For more details on this class, see [sklearn.ensemble.VotingClassifier]
@@ -164,7 +176,9 @@ class VotingClassifier(BaseTransformer):
164
176
  self.set_label_cols(label_cols)
165
177
  self.set_passthrough_cols(passthrough_cols)
166
178
  self.set_drop_input_cols(drop_input_cols)
167
- self.set_sample_weight_col(sample_weight_col)
179
+ self.set_sample_weight_col(sample_weight_col)
180
+ self._use_external_memory_version = False
181
+ self._batch_size = -1
168
182
  deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
169
183
  deps = deps | gather_dependencies(estimators)
170
184
  self._deps = list(deps)
@@ -245,11 +259,6 @@ class VotingClassifier(BaseTransformer):
245
259
  if isinstance(dataset, DataFrame):
246
260
  session = dataset._session
247
261
  assert session is not None # keep mypy happy
248
- # Validate that key package version in user workspace are supported in snowflake conda channel
249
- # If customer doesn't have package in conda channel, replace the ones have the closest versions
250
- self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
251
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
252
-
253
262
  # Specify input columns so column pruning will be enforced
254
263
  selected_cols = self._get_active_columns()
255
264
  if len(selected_cols) > 0:
@@ -277,7 +286,9 @@ class VotingClassifier(BaseTransformer):
277
286
  label_cols=self.label_cols,
278
287
  sample_weight_col=self.sample_weight_col,
279
288
  autogenerated=self._autogenerated,
280
- subproject=_SUBPROJECT
289
+ subproject=_SUBPROJECT,
290
+ use_external_memory_version=self._use_external_memory_version,
291
+ batch_size=self._batch_size,
281
292
  )
282
293
  self._sklearn_object = model_trainer.train()
283
294
  self._is_fitted = True
@@ -550,6 +561,22 @@ class VotingClassifier(BaseTransformer):
550
561
  # each row containing a list of values.
551
562
  expected_dtype = "ARRAY"
552
563
 
564
+ # If we were unable to assign a type to this transform in the factory, infer the type here.
565
+ if expected_dtype == "":
566
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
567
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
568
+ expected_dtype = "ARRAY"
569
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
570
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
571
+ expected_dtype = "ARRAY"
572
+ else:
573
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
574
+ # We can only infer the output types from the input types if the following two statements are true:
575
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
576
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
577
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
578
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
579
+
553
580
  output_df = self._batch_inference(
554
581
  dataset=dataset,
555
582
  inference_method="transform",
@@ -565,8 +592,8 @@ class VotingClassifier(BaseTransformer):
565
592
 
566
593
  return output_df
567
594
 
568
- @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
569
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
595
+ @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
596
+ def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
570
597
  """ Method not supported for this class.
571
598
 
572
599
 
@@ -579,13 +606,21 @@ class VotingClassifier(BaseTransformer):
579
606
  Returns:
580
607
  Predicted dataset.
581
608
  """
582
- if False:
583
- self.fit(dataset)
584
- assert self._sklearn_object is not None
585
- labels : npt.NDArray[Any] = self._sklearn_object.labels_
586
- return labels
587
- else:
588
- raise NotImplementedError
609
+ self.fit(dataset)
610
+ assert self._sklearn_object is not None
611
+ return self._sklearn_object.labels_
612
+
613
+
614
+ @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
615
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
616
+ """
617
+ Returns:
618
+ Transformed dataset.
619
+ """
620
+ self.fit(dataset)
621
+ assert self._sklearn_object is not None
622
+ return self._sklearn_object.embedding_
623
+
589
624
 
590
625
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
591
626
  """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
54
54
  _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklearn.", "").split("_")])
55
55
 
56
56
 
57
+ def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
58
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
59
+ return False and callable(getattr(self._sklearn_object, "fit_predict", None))
60
+ return check
61
+
62
+
63
+ def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
+ return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
+ return check
67
+
68
+
57
69
  class VotingRegressor(BaseTransformer):
58
70
  r"""Prediction voting regressor for unfitted estimators
59
71
  For more details on this class, see [sklearn.ensemble.VotingRegressor]
@@ -148,7 +160,9 @@ class VotingRegressor(BaseTransformer):
148
160
  self.set_label_cols(label_cols)
149
161
  self.set_passthrough_cols(passthrough_cols)
150
162
  self.set_drop_input_cols(drop_input_cols)
151
- self.set_sample_weight_col(sample_weight_col)
163
+ self.set_sample_weight_col(sample_weight_col)
164
+ self._use_external_memory_version = False
165
+ self._batch_size = -1
152
166
  deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
153
167
  deps = deps | gather_dependencies(estimators)
154
168
  self._deps = list(deps)
@@ -227,11 +241,6 @@ class VotingRegressor(BaseTransformer):
227
241
  if isinstance(dataset, DataFrame):
228
242
  session = dataset._session
229
243
  assert session is not None # keep mypy happy
230
- # Validate that key package version in user workspace are supported in snowflake conda channel
231
- # If customer doesn't have package in conda channel, replace the ones have the closest versions
232
- self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
233
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
234
-
235
244
  # Specify input columns so column pruning will be enforced
236
245
  selected_cols = self._get_active_columns()
237
246
  if len(selected_cols) > 0:
@@ -259,7 +268,9 @@ class VotingRegressor(BaseTransformer):
259
268
  label_cols=self.label_cols,
260
269
  sample_weight_col=self.sample_weight_col,
261
270
  autogenerated=self._autogenerated,
262
- subproject=_SUBPROJECT
271
+ subproject=_SUBPROJECT,
272
+ use_external_memory_version=self._use_external_memory_version,
273
+ batch_size=self._batch_size,
263
274
  )
264
275
  self._sklearn_object = model_trainer.train()
265
276
  self._is_fitted = True
@@ -532,6 +543,22 @@ class VotingRegressor(BaseTransformer):
532
543
  # each row containing a list of values.
533
544
  expected_dtype = "ARRAY"
534
545
 
546
+ # If we were unable to assign a type to this transform in the factory, infer the type here.
547
+ if expected_dtype == "":
548
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
549
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
550
+ expected_dtype = "ARRAY"
551
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
552
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
553
+ expected_dtype = "ARRAY"
554
+ else:
555
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
556
+ # We can only infer the output types from the input types if the following two statements are true:
557
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
558
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
559
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
560
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
561
+
535
562
  output_df = self._batch_inference(
536
563
  dataset=dataset,
537
564
  inference_method="transform",
@@ -547,8 +574,8 @@ class VotingRegressor(BaseTransformer):
547
574
 
548
575
  return output_df
549
576
 
550
- @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
551
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
577
+ @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
578
+ def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
552
579
  """ Method not supported for this class.
553
580
 
554
581
 
@@ -561,13 +588,21 @@ class VotingRegressor(BaseTransformer):
561
588
  Returns:
562
589
  Predicted dataset.
563
590
  """
564
- if False:
565
- self.fit(dataset)
566
- assert self._sklearn_object is not None
567
- labels : npt.NDArray[Any] = self._sklearn_object.labels_
568
- return labels
569
- else:
570
- raise NotImplementedError
591
+ self.fit(dataset)
592
+ assert self._sklearn_object is not None
593
+ return self._sklearn_object.labels_
594
+
595
+
596
+ @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
597
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
598
+ """
599
+ Returns:
600
+ Transformed dataset.
601
+ """
602
+ self.fit(dataset)
603
+ assert self._sklearn_object is not None
604
+ return self._sklearn_object.embedding_
605
+
571
606
 
572
607
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
573
608
  """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.