snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
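Entries 45 and 198-201 above add an external-memory training path for the XGBoost estimators; the `use_external_memory_version` and `batch_size` plumbing visible in the ensemble hunks below feeds that same trainer option. For orientation only, here is a minimal, hypothetical sketch of the underlying mechanism in plain XGBoost (a `DataIter` that pages batches to disk instead of materializing one large matrix); it illustrates the general technique, not the snowflake-ml trainer's actual code:

import numpy as np
import xgboost as xgb


class BatchIter(xgb.DataIter):
    """Feeds (X, y) chunks to XGBoost so the full matrix never sits in RAM."""

    def __init__(self, batches):
        self._batches = batches  # list of (X, y) chunks, e.g. file shards
        self._pos = 0
        super().__init__(cache_prefix="./dmatrix_cache")  # page data to disk

    def next(self, input_data):
        if self._pos == len(self._batches):
            return 0  # no more batches
        X, y = self._batches[self._pos]
        input_data(data=X, label=y)
        self._pos += 1
        return 1

    def reset(self):
        self._pos = 0


rng = np.random.default_rng(0)
batches = [(rng.random((256, 4)), rng.integers(0, 2, 256)) for _ in range(4)]
dtrain = xgb.DMatrix(BatchIter(batches))  # built incrementally, batch by batch
booster = xgb.train({"objective": "binary:logistic"}, dtrain, num_boost_round=10)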
snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class GradientBoostingClassifier(BaseTransformer):
     r"""Gradient Boosting for classification
     For more details on this class, see [sklearn.ensemble.GradientBoostingClassifier]
@@ -304,7 +316,9 @@ class GradientBoostingClassifier(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -399,11 +413,6 @@ class GradientBoostingClassifier(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -431,7 +440,9 @@ class GradientBoostingClassifier(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -702,6 +713,22 @@ class GradientBoostingClassifier(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+            # If we were unable to assign a type to this transform in the factory, infer the type here.
+            if expected_dtype == "":
+                # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+                if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                    expected_dtype = "ARRAY"
+                # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+                elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                    expected_dtype = "ARRAY"
+                else:
+                    output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                    # We can only infer the output types from the input types if the following two statemetns are true:
+                    # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                    # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                    if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                        expected_dtype = convert_sp_to_sf_type(output_types[0])
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="transform",
@@ -717,8 +744,8 @@ class GradientBoostingClassifier(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -731,13 +758,21 @@ class GradientBoostingClassifier(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class GradientBoostingRegressor(BaseTransformer):
     r"""Gradient Boosting for regression
     For more details on this class, see [sklearn.ensemble.GradientBoostingRegressor]
@@ -312,7 +324,9 @@ class GradientBoostingRegressor(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -408,11 +422,6 @@ class GradientBoostingRegressor(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -440,7 +449,9 @@ class GradientBoostingRegressor(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -711,6 +722,22 @@ class GradientBoostingRegressor(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+            # If we were unable to assign a type to this transform in the factory, infer the type here.
+            if expected_dtype == "":
+                # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+                if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                    expected_dtype = "ARRAY"
+                # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+                elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                    expected_dtype = "ARRAY"
+                else:
+                    output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                    # We can only infer the output types from the input types if the following two statemetns are true:
+                    # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                    # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                    if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                        expected_dtype = convert_sp_to_sf_type(output_types[0])
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="transform",
@@ -726,8 +753,8 @@ class GradientBoostingRegressor(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -740,13 +767,21 @@ class GradientBoostingRegressor(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class HistGradientBoostingClassifier(BaseTransformer):
     r"""Histogram-based Gradient Boosting Classification Tree
     For more details on this class, see [sklearn.ensemble.HistGradientBoostingClassifier]
@@ -285,7 +297,9 @@ class HistGradientBoostingClassifier(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -380,11 +394,6 @@ class HistGradientBoostingClassifier(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -412,7 +421,9 @@ class HistGradientBoostingClassifier(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -683,6 +694,22 @@ class HistGradientBoostingClassifier(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+            # If we were unable to assign a type to this transform in the factory, infer the type here.
+            if expected_dtype == "":
+                # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+                if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                    expected_dtype = "ARRAY"
+                # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+                elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                    expected_dtype = "ARRAY"
+                else:
+                    output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                    # We can only infer the output types from the input types if the following two statemetns are true:
+                    # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                    # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                    if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                        expected_dtype = convert_sp_to_sf_type(output_types[0])
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="transform",
@@ -698,8 +725,8 @@ class HistGradientBoostingClassifier(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -712,13 +739,21 @@ class HistGradientBoostingClassifier(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class HistGradientBoostingRegressor(BaseTransformer):
     r"""Histogram-based Gradient Boosting Regression Tree
     For more details on this class, see [sklearn.ensemble.HistGradientBoostingRegressor]
@@ -276,7 +288,9 @@ class HistGradientBoostingRegressor(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -371,11 +385,6 @@ class HistGradientBoostingRegressor(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -403,7 +412,9 @@ class HistGradientBoostingRegressor(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -674,6 +685,22 @@ class HistGradientBoostingRegressor(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+            # If we were unable to assign a type to this transform in the factory, infer the type here.
+            if expected_dtype == "":
+                # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+                if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                    expected_dtype = "ARRAY"
+                # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+                elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                    expected_dtype = "ARRAY"
+                else:
+                    output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                    # We can only infer the output types from the input types if the following two statemetns are true:
+                    # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                    # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                    if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                        expected_dtype = convert_sp_to_sf_type(output_types[0])
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="transform",
@@ -689,8 +716,8 @@ class HistGradientBoostingRegressor(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -703,13 +730,21 @@ class HistGradientBoostingRegressor(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/ensemble/isolation_forest.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class IsolationForest(BaseTransformer):
     r"""Isolation Forest Algorithm
     For more details on this class, see [sklearn.ensemble.IsolationForest]
@@ -187,7 +199,9 @@ class IsolationForest(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -271,11 +285,6 @@ class IsolationForest(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -303,7 +312,9 @@ class IsolationForest(BaseTransformer):
             label_cols=self.label_cols,
            sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -574,6 +585,22 @@ class IsolationForest(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+            # If we were unable to assign a type to this transform in the factory, infer the type here.
+            if expected_dtype == "":
+                # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+                if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                    expected_dtype = "ARRAY"
+                # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+                elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                    expected_dtype = "ARRAY"
+                else:
+                    output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                    # We can only infer the output types from the input types if the following two statemetns are true:
+                    # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                    # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                    if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                        expected_dtype = convert_sp_to_sf_type(output_types[0])
+
             output_df = self._batch_inference(
                 dataset=dataset,
                 inference_method="transform",
@@ -589,8 +616,8 @@ class IsolationForest(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Perform fit on X and returns labels for X
         For more details on this function, see [sklearn.ensemble.IsolationForest.fit_predict]
         (https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.IsolationForest.html#sklearn.ensemble.IsolationForest.fit_predict)
@@ -605,13 +632,21 @@ class IsolationForest(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.