snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
--- a/snowflake/ml/modeling/cluster/spectral_clustering.py
+++ b/snowflake/ml/modeling/cluster/spectral_clustering.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return True and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class SpectralClustering(BaseTransformer):
     r"""Apply clustering to a projection of the normalized Laplacian
     For more details on this class, see [sklearn.cluster.SpectralClustering]
@@ -237,7 +249,9 @@ class SpectralClustering(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -327,11 +341,6 @@ class SpectralClustering(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -359,7 +368,9 @@ class SpectralClustering(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -628,6 +639,22 @@ class SpectralClustering(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -643,8 +670,8 @@ class SpectralClustering(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Perform spectral clustering on `X` and return cluster labels
         For more details on this function, see [sklearn.cluster.SpectralClustering.fit_predict]
         (https://scikit-learn.org/stable/modules/generated/sklearn.cluster.SpectralClustering.html#sklearn.cluster.SpectralClustering.fit_predict)
@@ -659,13 +686,21 @@ class SpectralClustering(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if True:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
--- a/snowflake/ml/modeling/cluster/spectral_coclustering.py
+++ b/snowflake/ml/modeling/cluster/spectral_coclustering.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class SpectralCoclustering(BaseTransformer):
     r"""Spectral Co-Clustering algorithm (Dhillon, 2001)
     For more details on this class, see [sklearn.cluster.SpectralCoclustering]
@@ -166,7 +178,9 @@ class SpectralCoclustering(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -248,11 +262,6 @@ class SpectralCoclustering(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -280,7 +289,9 @@ class SpectralCoclustering(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -549,6 +560,22 @@ class SpectralCoclustering(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -564,8 +591,8 @@ class SpectralCoclustering(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -578,13 +605,21 @@ class SpectralCoclustering(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
--- a/snowflake/ml/modeling/compose/column_transformer.py
+++ b/snowflake/ml/modeling/compose/column_transformer.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.compose".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class ColumnTransformer(BaseTransformer):
     r"""Applies transformers to columns of an array or pandas DataFrame
     For more details on this class, see [sklearn.compose.ColumnTransformer]
@@ -196,7 +208,9 @@ class ColumnTransformer(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
         deps = deps | gather_dependencies(transformers)
         self._deps = list(deps)
@@ -278,11 +292,6 @@ class ColumnTransformer(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -310,7 +319,9 @@ class ColumnTransformer(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -581,6 +592,22 @@ class ColumnTransformer(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -596,8 +623,8 @@ class ColumnTransformer(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -610,13 +637,21 @@ class ColumnTransformer(BaseTransformer):
         Returns:
            Predicted dataset.
        """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
--- a/snowflake/ml/modeling/compose/transformed_target_regressor.py
+++ b/snowflake/ml/modeling/compose/transformed_target_regressor.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.compose".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class TransformedTargetRegressor(BaseTransformer):
     r"""Meta-estimator to regress on a transformed target
     For more details on this class, see [sklearn.compose.TransformedTargetRegressor]
@@ -159,7 +171,9 @@ class TransformedTargetRegressor(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -239,11 +253,6 @@ class TransformedTargetRegressor(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -271,7 +280,9 @@ class TransformedTargetRegressor(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -542,6 +553,22 @@ class TransformedTargetRegressor(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -557,8 +584,8 @@ class TransformedTargetRegressor(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -571,13 +598,21 @@ class TransformedTargetRegressor(BaseTransformer):
         Returns:
            Predicted dataset.
        """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
--- a/snowflake/ml/modeling/covariance/elliptic_envelope.py
+++ b/snowflake/ml/modeling/covariance/elliptic_envelope.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.covariance".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class EllipticEnvelope(BaseTransformer):
     r"""An object for detecting outliers in a Gaussian distributed dataset
     For more details on this class, see [sklearn.covariance.EllipticEnvelope]
@@ -154,7 +166,9 @@ class EllipticEnvelope(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -234,11 +248,6 @@ class EllipticEnvelope(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -266,7 +275,9 @@ class EllipticEnvelope(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -537,6 +548,22 @@ class EllipticEnvelope(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -552,8 +579,8 @@ class EllipticEnvelope(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Perform fit on X and returns labels for X
         For more details on this function, see [sklearn.covariance.EllipticEnvelope.fit_predict]
         (https://scikit-learn.org/stable/modules/generated/sklearn.covariance.EllipticEnvelope.html#sklearn.covariance.EllipticEnvelope.fit_predict)
@@ -568,13 +595,21 @@ class EllipticEnvelope(BaseTransformer):
         Returns:
            Predicted dataset.
        """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.