snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
54
54
  _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.svm".replace("sklearn.", "").split("_")])
55
55
 
56
56
 
57
+ def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
58
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
59
+ return False and callable(getattr(self._sklearn_object, "fit_predict", None))
60
+ return check
61
+
62
+
63
+ def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
+ return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
+ return check
67
+
68
+
57
69
  class SVR(BaseTransformer):
58
70
  r"""Epsilon-Support Vector Regression
59
71
  For more details on this class, see [sklearn.svm.SVR]
@@ -185,7 +197,9 @@ class SVR(BaseTransformer):
185
197
  self.set_label_cols(label_cols)
186
198
  self.set_passthrough_cols(passthrough_cols)
187
199
  self.set_drop_input_cols(drop_input_cols)
188
- self.set_sample_weight_col(sample_weight_col)
200
+ self.set_sample_weight_col(sample_weight_col)
201
+ self._use_external_memory_version = False
202
+ self._batch_size = -1
189
203
  deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
190
204
 
191
205
  self._deps = list(deps)
@@ -271,11 +285,6 @@ class SVR(BaseTransformer):
271
285
  if isinstance(dataset, DataFrame):
272
286
  session = dataset._session
273
287
  assert session is not None # keep mypy happy
274
- # Validate that key package version in user workspace are supported in snowflake conda channel
275
- # If customer doesn't have package in conda channel, replace the ones have the closest versions
276
- self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
277
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
278
-
279
288
  # Specify input columns so column pruning will be enforced
280
289
  selected_cols = self._get_active_columns()
281
290
  if len(selected_cols) > 0:
@@ -303,7 +312,9 @@ class SVR(BaseTransformer):
303
312
  label_cols=self.label_cols,
304
313
  sample_weight_col=self.sample_weight_col,
305
314
  autogenerated=self._autogenerated,
306
- subproject=_SUBPROJECT
315
+ subproject=_SUBPROJECT,
316
+ use_external_memory_version=self._use_external_memory_version,
317
+ batch_size=self._batch_size,
307
318
  )
308
319
  self._sklearn_object = model_trainer.train()
309
320
  self._is_fitted = True
@@ -574,6 +585,22 @@ class SVR(BaseTransformer):
574
585
  # each row containing a list of values.
575
586
  expected_dtype = "ARRAY"
576
587
 
588
+ # If we were unable to assign a type to this transform in the factory, infer the type here.
589
+ if expected_dtype == "":
590
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
591
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
592
+ expected_dtype = "ARRAY"
593
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
594
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
595
+ expected_dtype = "ARRAY"
596
+ else:
597
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
598
+ # We can only infer the output types from the input types if the following two statements are true:
599
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
600
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
601
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
602
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
603
+
577
604
  output_df = self._batch_inference(
578
605
  dataset=dataset,
579
606
  inference_method="transform",
@@ -589,8 +616,8 @@ class SVR(BaseTransformer):
589
616
 
590
617
  return output_df
591
618
 
592
- @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
593
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
619
+ @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
620
+ def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
594
621
  """ Method not supported for this class.
595
622
 
596
623
 
@@ -603,13 +630,21 @@ class SVR(BaseTransformer):
603
630
  Returns:
604
631
  Predicted dataset.
605
632
  """
606
- if False:
607
- self.fit(dataset)
608
- assert self._sklearn_object is not None
609
- labels : npt.NDArray[Any] = self._sklearn_object.labels_
610
- return labels
611
- else:
612
- raise NotImplementedError
633
+ self.fit(dataset)
634
+ assert self._sklearn_object is not None
635
+ return self._sklearn_object.labels_
636
+
637
+
638
+ @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
639
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
640
+ """
641
+ Returns:
642
+ Transformed dataset.
643
+ """
644
+ self.fit(dataset)
645
+ assert self._sklearn_object is not None
646
+ return self._sklearn_object.embedding_
647
+
613
648
 
614
649
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
615
650
  """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
54
54
  _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.tree".replace("sklearn.", "").split("_")])
55
55
 
56
56
 
57
+ def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
58
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
59
+ return False and callable(getattr(self._sklearn_object, "fit_predict", None))
60
+ return check
61
+
62
+
63
+ def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
+ return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
+ return check
67
+
68
+
57
69
  class DecisionTreeClassifier(BaseTransformer):
58
70
  r"""A decision tree classifier
59
71
  For more details on this class, see [sklearn.tree.DecisionTreeClassifier]
@@ -251,7 +263,9 @@ class DecisionTreeClassifier(BaseTransformer):
251
263
  self.set_label_cols(label_cols)
252
264
  self.set_passthrough_cols(passthrough_cols)
253
265
  self.set_drop_input_cols(drop_input_cols)
254
- self.set_sample_weight_col(sample_weight_col)
266
+ self.set_sample_weight_col(sample_weight_col)
267
+ self._use_external_memory_version = False
268
+ self._batch_size = -1
255
269
  deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
256
270
 
257
271
  self._deps = list(deps)
@@ -338,11 +352,6 @@ class DecisionTreeClassifier(BaseTransformer):
338
352
  if isinstance(dataset, DataFrame):
339
353
  session = dataset._session
340
354
  assert session is not None # keep mypy happy
341
- # Validate that key package version in user workspace are supported in snowflake conda channel
342
- # If customer doesn't have package in conda channel, replace the ones have the closest versions
343
- self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
344
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
345
-
346
355
  # Specify input columns so column pruning will be enforced
347
356
  selected_cols = self._get_active_columns()
348
357
  if len(selected_cols) > 0:
@@ -370,7 +379,9 @@ class DecisionTreeClassifier(BaseTransformer):
370
379
  label_cols=self.label_cols,
371
380
  sample_weight_col=self.sample_weight_col,
372
381
  autogenerated=self._autogenerated,
373
- subproject=_SUBPROJECT
382
+ subproject=_SUBPROJECT,
383
+ use_external_memory_version=self._use_external_memory_version,
384
+ batch_size=self._batch_size,
374
385
  )
375
386
  self._sklearn_object = model_trainer.train()
376
387
  self._is_fitted = True
@@ -641,6 +652,22 @@ class DecisionTreeClassifier(BaseTransformer):
641
652
  # each row containing a list of values.
642
653
  expected_dtype = "ARRAY"
643
654
 
655
+ # If we were unable to assign a type to this transform in the factory, infer the type here.
656
+ if expected_dtype == "":
657
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
658
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
659
+ expected_dtype = "ARRAY"
660
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
661
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
662
+ expected_dtype = "ARRAY"
663
+ else:
664
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
665
+ # We can only infer the output types from the input types if the following two statements are true:
666
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
667
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
668
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
669
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
670
+
644
671
  output_df = self._batch_inference(
645
672
  dataset=dataset,
646
673
  inference_method="transform",
@@ -656,8 +683,8 @@ class DecisionTreeClassifier(BaseTransformer):
656
683
 
657
684
  return output_df
658
685
 
659
- @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
660
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
686
+ @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
687
+ def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
661
688
  """ Method not supported for this class.
662
689
 
663
690
 
@@ -670,13 +697,21 @@ class DecisionTreeClassifier(BaseTransformer):
670
697
  Returns:
671
698
  Predicted dataset.
672
699
  """
673
- if False:
674
- self.fit(dataset)
675
- assert self._sklearn_object is not None
676
- labels : npt.NDArray[Any] = self._sklearn_object.labels_
677
- return labels
678
- else:
679
- raise NotImplementedError
700
+ self.fit(dataset)
701
+ assert self._sklearn_object is not None
702
+ return self._sklearn_object.labels_
703
+
704
+
705
+ @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
706
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
707
+ """
708
+ Returns:
709
+ Transformed dataset.
710
+ """
711
+ self.fit(dataset)
712
+ assert self._sklearn_object is not None
713
+ return self._sklearn_object.embedding_
714
+
680
715
 
681
716
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
682
717
  """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
54
54
  _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.tree".replace("sklearn.", "").split("_")])
55
55
 
56
56
 
57
+ def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
58
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
59
+ return False and callable(getattr(self._sklearn_object, "fit_predict", None))
60
+ return check
61
+
62
+
63
+ def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
+ return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
+ return check
67
+
68
+
57
69
  class DecisionTreeRegressor(BaseTransformer):
58
70
  r"""A decision tree regressor
59
71
  For more details on this class, see [sklearn.tree.DecisionTreeRegressor]
@@ -234,7 +246,9 @@ class DecisionTreeRegressor(BaseTransformer):
234
246
  self.set_label_cols(label_cols)
235
247
  self.set_passthrough_cols(passthrough_cols)
236
248
  self.set_drop_input_cols(drop_input_cols)
237
- self.set_sample_weight_col(sample_weight_col)
249
+ self.set_sample_weight_col(sample_weight_col)
250
+ self._use_external_memory_version = False
251
+ self._batch_size = -1
238
252
  deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
239
253
 
240
254
  self._deps = list(deps)
@@ -320,11 +334,6 @@ class DecisionTreeRegressor(BaseTransformer):
320
334
  if isinstance(dataset, DataFrame):
321
335
  session = dataset._session
322
336
  assert session is not None # keep mypy happy
323
- # Validate that key package version in user workspace are supported in snowflake conda channel
324
- # If customer doesn't have package in conda channel, replace the ones have the closest versions
325
- self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
326
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
327
-
328
337
  # Specify input columns so column pruning will be enforced
329
338
  selected_cols = self._get_active_columns()
330
339
  if len(selected_cols) > 0:
@@ -352,7 +361,9 @@ class DecisionTreeRegressor(BaseTransformer):
352
361
  label_cols=self.label_cols,
353
362
  sample_weight_col=self.sample_weight_col,
354
363
  autogenerated=self._autogenerated,
355
- subproject=_SUBPROJECT
364
+ subproject=_SUBPROJECT,
365
+ use_external_memory_version=self._use_external_memory_version,
366
+ batch_size=self._batch_size,
356
367
  )
357
368
  self._sklearn_object = model_trainer.train()
358
369
  self._is_fitted = True
@@ -623,6 +634,22 @@ class DecisionTreeRegressor(BaseTransformer):
623
634
  # each row containing a list of values.
624
635
  expected_dtype = "ARRAY"
625
636
 
637
+ # If we were unable to assign a type to this transform in the factory, infer the type here.
638
+ if expected_dtype == "":
639
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
640
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
641
+ expected_dtype = "ARRAY"
642
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
643
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
644
+ expected_dtype = "ARRAY"
645
+ else:
646
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
647
+ # We can only infer the output types from the input types if the following two statements are true:
648
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
649
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
650
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
651
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
652
+
626
653
  output_df = self._batch_inference(
627
654
  dataset=dataset,
628
655
  inference_method="transform",
@@ -638,8 +665,8 @@ class DecisionTreeRegressor(BaseTransformer):
638
665
 
639
666
  return output_df
640
667
 
641
- @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
642
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
668
+ @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
669
+ def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
643
670
  """ Method not supported for this class.
644
671
 
645
672
 
@@ -652,13 +679,21 @@ class DecisionTreeRegressor(BaseTransformer):
652
679
  Returns:
653
680
  Predicted dataset.
654
681
  """
655
- if False:
656
- self.fit(dataset)
657
- assert self._sklearn_object is not None
658
- labels : npt.NDArray[Any] = self._sklearn_object.labels_
659
- return labels
660
- else:
661
- raise NotImplementedError
682
+ self.fit(dataset)
683
+ assert self._sklearn_object is not None
684
+ return self._sklearn_object.labels_
685
+
686
+
687
+ @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
688
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
689
+ """
690
+ Returns:
691
+ Transformed dataset.
692
+ """
693
+ self.fit(dataset)
694
+ assert self._sklearn_object is not None
695
+ return self._sklearn_object.embedding_
696
+
662
697
 
663
698
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
664
699
  """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
54
54
  _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.tree".replace("sklearn.", "").split("_")])
55
55
 
56
56
 
57
+ def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
58
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
59
+ return False and callable(getattr(self._sklearn_object, "fit_predict", None))
60
+ return check
61
+
62
+
63
+ def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
+ return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
+ return check
67
+
68
+
57
69
  class ExtraTreeClassifier(BaseTransformer):
58
70
  r"""An extremely randomized tree classifier
59
71
  For more details on this class, see [sklearn.tree.ExtraTreeClassifier]
@@ -243,7 +255,9 @@ class ExtraTreeClassifier(BaseTransformer):
243
255
  self.set_label_cols(label_cols)
244
256
  self.set_passthrough_cols(passthrough_cols)
245
257
  self.set_drop_input_cols(drop_input_cols)
246
- self.set_sample_weight_col(sample_weight_col)
258
+ self.set_sample_weight_col(sample_weight_col)
259
+ self._use_external_memory_version = False
260
+ self._batch_size = -1
247
261
  deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
248
262
 
249
263
  self._deps = list(deps)
@@ -330,11 +344,6 @@ class ExtraTreeClassifier(BaseTransformer):
330
344
  if isinstance(dataset, DataFrame):
331
345
  session = dataset._session
332
346
  assert session is not None # keep mypy happy
333
- # Validate that key package version in user workspace are supported in snowflake conda channel
334
- # If customer doesn't have package in conda channel, replace the ones have the closest versions
335
- self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
336
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
337
-
338
347
  # Specify input columns so column pruning will be enforced
339
348
  selected_cols = self._get_active_columns()
340
349
  if len(selected_cols) > 0:
@@ -362,7 +371,9 @@ class ExtraTreeClassifier(BaseTransformer):
362
371
  label_cols=self.label_cols,
363
372
  sample_weight_col=self.sample_weight_col,
364
373
  autogenerated=self._autogenerated,
365
- subproject=_SUBPROJECT
374
+ subproject=_SUBPROJECT,
375
+ use_external_memory_version=self._use_external_memory_version,
376
+ batch_size=self._batch_size,
366
377
  )
367
378
  self._sklearn_object = model_trainer.train()
368
379
  self._is_fitted = True
@@ -633,6 +644,22 @@ class ExtraTreeClassifier(BaseTransformer):
633
644
  # each row containing a list of values.
634
645
  expected_dtype = "ARRAY"
635
646
 
647
+ # If we were unable to assign a type to this transform in the factory, infer the type here.
648
+ if expected_dtype == "":
649
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
650
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
651
+ expected_dtype = "ARRAY"
652
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
653
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
654
+ expected_dtype = "ARRAY"
655
+ else:
656
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
657
+ # We can only infer the output types from the input types if the following two statements are true:
658
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
659
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
660
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
661
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
662
+
636
663
  output_df = self._batch_inference(
637
664
  dataset=dataset,
638
665
  inference_method="transform",
@@ -648,8 +675,8 @@ class ExtraTreeClassifier(BaseTransformer):
648
675
 
649
676
  return output_df
650
677
 
651
- @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
652
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
678
+ @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
679
+ def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
653
680
  """ Method not supported for this class.
654
681
 
655
682
 
@@ -662,13 +689,21 @@ class ExtraTreeClassifier(BaseTransformer):
662
689
  Returns:
663
690
  Predicted dataset.
664
691
  """
665
- if False:
666
- self.fit(dataset)
667
- assert self._sklearn_object is not None
668
- labels : npt.NDArray[Any] = self._sklearn_object.labels_
669
- return labels
670
- else:
671
- raise NotImplementedError
692
+ self.fit(dataset)
693
+ assert self._sklearn_object is not None
694
+ return self._sklearn_object.labels_
695
+
696
+
697
+ @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
698
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
699
+ """
700
+ Returns:
701
+ Transformed dataset.
702
+ """
703
+ self.fit(dataset)
704
+ assert self._sklearn_object is not None
705
+ return self._sklearn_object.embedding_
706
+
672
707
 
673
708
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
674
709
  """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
54
54
  _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.tree".replace("sklearn.", "").split("_")])
55
55
 
56
56
 
57
+ def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
58
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
59
+ return False and callable(getattr(self._sklearn_object, "fit_predict", None))
60
+ return check
61
+
62
+
63
+ def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
+ return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
+ return check
67
+
68
+
57
69
  class ExtraTreeRegressor(BaseTransformer):
58
70
  r"""An extremely randomized tree regressor
59
71
  For more details on this class, see [sklearn.tree.ExtraTreeRegressor]
@@ -226,7 +238,9 @@ class ExtraTreeRegressor(BaseTransformer):
226
238
  self.set_label_cols(label_cols)
227
239
  self.set_passthrough_cols(passthrough_cols)
228
240
  self.set_drop_input_cols(drop_input_cols)
229
- self.set_sample_weight_col(sample_weight_col)
241
+ self.set_sample_weight_col(sample_weight_col)
242
+ self._use_external_memory_version = False
243
+ self._batch_size = -1
230
244
  deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
231
245
 
232
246
  self._deps = list(deps)
@@ -312,11 +326,6 @@ class ExtraTreeRegressor(BaseTransformer):
312
326
  if isinstance(dataset, DataFrame):
313
327
  session = dataset._session
314
328
  assert session is not None # keep mypy happy
315
- # Validate that key package version in user workspace are supported in snowflake conda channel
316
- # If customer doesn't have package in conda channel, replace the ones have the closest versions
317
- self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
318
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
319
-
320
329
  # Specify input columns so column pruning will be enforced
321
330
  selected_cols = self._get_active_columns()
322
331
  if len(selected_cols) > 0:
@@ -344,7 +353,9 @@ class ExtraTreeRegressor(BaseTransformer):
344
353
  label_cols=self.label_cols,
345
354
  sample_weight_col=self.sample_weight_col,
346
355
  autogenerated=self._autogenerated,
347
- subproject=_SUBPROJECT
356
+ subproject=_SUBPROJECT,
357
+ use_external_memory_version=self._use_external_memory_version,
358
+ batch_size=self._batch_size,
348
359
  )
349
360
  self._sklearn_object = model_trainer.train()
350
361
  self._is_fitted = True
@@ -615,6 +626,22 @@ class ExtraTreeRegressor(BaseTransformer):
615
626
  # each row containing a list of values.
616
627
  expected_dtype = "ARRAY"
617
628
 
629
+ # If we were unable to assign a type to this transform in the factory, infer the type here.
630
+ if expected_dtype == "":
631
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
632
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
633
+ expected_dtype = "ARRAY"
634
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
635
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
636
+ expected_dtype = "ARRAY"
637
+ else:
638
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
639
+ # We can only infer the output types from the input types if the following two statements are true:
640
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
641
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
642
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
643
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
644
+
618
645
  output_df = self._batch_inference(
619
646
  dataset=dataset,
620
647
  inference_method="transform",
@@ -630,8 +657,8 @@ class ExtraTreeRegressor(BaseTransformer):
630
657
 
631
658
  return output_df
632
659
 
633
- @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
634
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
660
+ @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
661
+ def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
635
662
  """ Method not supported for this class.
636
663
 
637
664
 
@@ -644,13 +671,21 @@ class ExtraTreeRegressor(BaseTransformer):
644
671
  Returns:
645
672
  Predicted dataset.
646
673
  """
647
- if False:
648
- self.fit(dataset)
649
- assert self._sklearn_object is not None
650
- labels : npt.NDArray[Any] = self._sklearn_object.labels_
651
- return labels
652
- else:
653
- raise NotImplementedError
674
+ self.fit(dataset)
675
+ assert self._sklearn_object is not None
676
+ return self._sklearn_object.labels_
677
+
678
+
679
+ @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
680
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
681
+ """
682
+ Returns:
683
+ Transformed dataset.
684
+ """
685
+ self.fit(dataset)
686
+ assert self._sklearn_object is not None
687
+ return self._sklearn_object.embedding_
688
+
654
689
 
655
690
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
656
691
  """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.