snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0

snowflake/ml/modeling/covariance/empirical_covariance.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.covariance".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class EmpiricalCovariance(BaseTransformer):
     r"""Maximum likelihood covariance estimator
     For more details on this class, see [sklearn.covariance.EmpiricalCovariance]
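
Across the regenerated modeling wrappers, the old `original_estimator_has_callable(...)` availability check is replaced by these module-level `_is_fit_predict_method_enabled` / `_is_fit_transform_method_enabled` factories. They feed scikit-learn's `available_if` decorator, which hides a method whenever its check returns a falsy value; because the generated check hardcodes `False and ...`, `fit_predict` stays unavailable on these covariance estimators. A minimal, self-contained sketch of the gating pattern (the `Wrapper` class is an illustrative stand-in, not snowflake-ml-python code):

    # Sketch of the available_if gating used by the generated wrappers.
    from sklearn.covariance import EmpiricalCovariance as SkEmpiricalCovariance
    from sklearn.utils.metaestimators import available_if

    def _is_fit_predict_method_enabled():
        def check(self):
            # `False and ...` always evaluates to False, so the method is
            # never exposed on covariance estimators.
            return False and callable(getattr(self._sklearn_object, "fit_predict", None))
        return check

    class Wrapper:  # illustrative stand-in for the BaseTransformer subclasses
        def __init__(self):
            self._sklearn_object = SkEmpiricalCovariance()

        @available_if(_is_fit_predict_method_enabled())
        def fit_predict(self, dataset):
            return self._sklearn_object.fit_predict(dataset)

    print(hasattr(Wrapper(), "fit_predict"))  # False: the check gates the attribute itself

Accessing `fit_predict` on such a wrapper raises AttributeError, which is why the docstrings below still read "Method not supported for this class."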
@@ -133,7 +145,9 @@ class EmpiricalCovariance(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -210,11 +224,6 @@ class EmpiricalCovariance(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -242,7 +251,9 @@ class EmpiricalCovariance(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
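
The `use_external_memory_version` and `batch_size` values threaded through to the trainer here pair with the new `xgboost_external_memory_trainer.py` (file 45 in the list above); covariance estimators simply pin them to `False`/`-1`, and presumably only the XGBoost wrappers (files 198-201, with their larger +69 -16 deltas) wire them up. For context, XGBoost's external-memory mode streams batches through a `DataIter` instead of materializing the whole training matrix. A hedged sketch of that idea, using synthetic data rather than the trainer's actual code:

    import numpy as np
    import xgboost as xgb

    class BatchIter(xgb.DataIter):
        # Feeds fixed-size batches so XGBoost can spill cached pages to disk
        # instead of holding the full dataset in memory.
        def __init__(self, X, y, batch_size):
            self._X, self._y, self._bs, self._i = X, y, batch_size, 0
            super().__init__(cache_prefix="xgb_cache")  # on-disk page cache

        def next(self, input_data):
            if self._i >= len(self._X):
                return 0  # signal end of epoch
            s = slice(self._i, self._i + self._bs)
            input_data(data=self._X[s], label=self._y[s])
            self._i += self._bs
            return 1

        def reset(self):
            self._i = 0

    X, y = np.random.rand(1000, 4), np.random.randint(0, 2, 1000)
    dtrain = xgb.DMatrix(BatchIter(X, y, batch_size=256))
    booster = xgb.train({"objective": "binary:logistic"}, dtrain, num_boost_round=5)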
@@ -511,6 +522,22 @@ class EmpiricalCovariance(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
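
The fallback added above reads as a small decision tree. A hedged distillation as a standalone function (`infer_expected_dtype` and its flattened arguments are illustrative; the real code reads these values off the fitted estimator and the inferred Snowpark signature):

    from typing import List, Optional

    def infer_expected_dtype(n_clusters: Optional[int], n_components: Optional[int],
                             input_types: List[str], n_output_cols: int) -> str:
        # A cluster/component count that disagrees with the output column
        # count means each row holds a list, hence a variant ARRAY column.
        if n_clusters is not None and n_clusters != n_output_cols:
            return "ARRAY"
        if n_components is not None and n_components != n_output_cols:
            return "ARRAY"
        # Reuse the input type only when all inputs agree and the column
        # counts line up; otherwise leave the type undecided, as the diff does.
        if input_types and all(t == input_types[0] for t in input_types) and len(input_types) == n_output_cols:
            return input_types[0]
        return ""

    assert infer_expected_dtype(8, None, ["FLOAT"] * 4, 1) == "ARRAY"      # e.g. a KMeans-style transform
    assert infer_expected_dtype(None, None, ["FLOAT", "FLOAT"], 2) == "FLOAT"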
@@ -526,8 +553,8 @@ class EmpiricalCovariance(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -540,13 +567,21 @@ class EmpiricalCovariance(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.

snowflake/ml/modeling/covariance/graphical_lasso.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.covariance".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class GraphicalLasso(BaseTransformer):
     r"""Sparse inverse covariance estimation with an l1-penalized estimator
     For more details on this class, see [sklearn.covariance.GraphicalLasso]
@@ -174,7 +186,9 @@ class GraphicalLasso(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -258,11 +272,6 @@ class GraphicalLasso(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -290,7 +299,9 @@ class GraphicalLasso(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
            autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -559,6 +570,22 @@ class GraphicalLasso(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -574,8 +601,8 @@ class GraphicalLasso(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -588,13 +615,21 @@ class GraphicalLasso(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.

snowflake/ml/modeling/covariance/graphical_lasso_cv.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.covariance".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class GraphicalLassoCV(BaseTransformer):
     r"""Sparse inverse covariance w/ cross-validated choice of the l1 penalty
     For more details on this class, see [sklearn.covariance.GraphicalLassoCV]
@@ -198,7 +210,9 @@ class GraphicalLassoCV(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -284,11 +298,6 @@ class GraphicalLassoCV(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -316,7 +325,9 @@ class GraphicalLassoCV(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -585,6 +596,22 @@ class GraphicalLassoCV(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -600,8 +627,8 @@ class GraphicalLassoCV(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -614,13 +641,21 @@ class GraphicalLassoCV(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.

snowflake/ml/modeling/covariance/ledoit_wolf.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.covariance".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class LedoitWolf(BaseTransformer):
     r"""LedoitWolf Estimator
     For more details on this class, see [sklearn.covariance.LedoitWolf]
@@ -139,7 +151,9 @@ class LedoitWolf(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -217,11 +231,6 @@ class LedoitWolf(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -249,7 +258,9 @@ class LedoitWolf(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -518,6 +529,22 @@ class LedoitWolf(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -533,8 +560,8 @@ class LedoitWolf(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -547,13 +574,21 @@ class LedoitWolf(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.

snowflake/ml/modeling/covariance/min_cov_det.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.covariance".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class MinCovDet(BaseTransformer):
     r"""Minimum Covariance Determinant (MCD): robust estimator of covariance
     For more details on this class, see [sklearn.covariance.MinCovDet]
@@ -150,7 +162,9 @@ class MinCovDet(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -229,11 +243,6 @@ class MinCovDet(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -261,7 +270,9 @@ class MinCovDet(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -530,6 +541,22 @@ class MinCovDet(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -545,8 +572,8 @@ class MinCovDet(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -559,13 +586,21 @@ class MinCovDet(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.