snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
snowflake/ml/modeling/cluster/k_means.py (+51 -16)

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return True and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class KMeans(BaseTransformer):
     r"""K-Means clustering
     For more details on this class, see [sklearn.cluster.KMeans]
@@ -201,7 +213,9 @@ class KMeans(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -285,11 +299,6 @@ class KMeans(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -317,7 +326,9 @@ class KMeans(BaseTransformer):
            label_cols=self.label_cols,
            sample_weight_col=self.sample_weight_col,
            autogenerated=self._autogenerated,
-           subproject=_SUBPROJECT
+           subproject=_SUBPROJECT,
+           use_external_memory_version=self._use_external_memory_version,
+           batch_size=self._batch_size,
        )
        self._sklearn_object = model_trainer.train()
        self._is_fitted = True
@@ -590,6 +601,22 @@ class KMeans(BaseTransformer):
                # each row containing a list of values.
                expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
        output_df = self._batch_inference(
            dataset=dataset,
            inference_method="transform",
@@ -605,8 +632,8 @@ class KMeans(BaseTransformer):

        return output_df

-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Compute cluster centers and predict cluster index for each sample
         For more details on this function, see [sklearn.cluster.KMeans.fit_predict]
         (https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html#sklearn.cluster.KMeans.fit_predict)
@@ -621,13 +648,21 @@ class KMeans(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if True:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/cluster/mean_shift.py (+51 -16)

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return True and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class MeanShift(BaseTransformer):
     r"""Mean shift clustering using a flat kernel
     For more details on this class, see [sklearn.cluster.MeanShift]
@@ -179,7 +191,9 @@ class MeanShift(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -261,11 +275,6 @@ class MeanShift(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -293,7 +302,9 @@ class MeanShift(BaseTransformer):
            label_cols=self.label_cols,
            sample_weight_col=self.sample_weight_col,
            autogenerated=self._autogenerated,
-           subproject=_SUBPROJECT
+           subproject=_SUBPROJECT,
+           use_external_memory_version=self._use_external_memory_version,
+           batch_size=self._batch_size,
        )
        self._sklearn_object = model_trainer.train()
        self._is_fitted = True
@@ -564,6 +575,22 @@ class MeanShift(BaseTransformer):
                # each row containing a list of values.
                expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
        output_df = self._batch_inference(
            dataset=dataset,
            inference_method="transform",
@@ -579,8 +606,8 @@ class MeanShift(BaseTransformer):

        return output_df

-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Perform clustering on `X` and returns cluster labels
         For more details on this function, see [sklearn.cluster.MeanShift.fit_predict]
         (https://scikit-learn.org/stable/modules/generated/sklearn.cluster.MeanShift.html#sklearn.cluster.MeanShift.fit_predict)
@@ -595,13 +622,21 @@ class MeanShift(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if True:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/cluster/mini_batch_k_means.py (+51 -16)

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return True and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class MiniBatchKMeans(BaseTransformer):
     r"""Mini-Batch K-Means clustering
     For more details on this class, see [sklearn.cluster.MiniBatchKMeans]
@@ -224,7 +236,9 @@ class MiniBatchKMeans(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -311,11 +325,6 @@ class MiniBatchKMeans(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -343,7 +352,9 @@ class MiniBatchKMeans(BaseTransformer):
            label_cols=self.label_cols,
            sample_weight_col=self.sample_weight_col,
            autogenerated=self._autogenerated,
-           subproject=_SUBPROJECT
+           subproject=_SUBPROJECT,
+           use_external_memory_version=self._use_external_memory_version,
+           batch_size=self._batch_size,
        )
        self._sklearn_object = model_trainer.train()
        self._is_fitted = True
@@ -616,6 +627,22 @@ class MiniBatchKMeans(BaseTransformer):
                # each row containing a list of values.
                expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
        output_df = self._batch_inference(
            dataset=dataset,
            inference_method="transform",
@@ -631,8 +658,8 @@ class MiniBatchKMeans(BaseTransformer):

        return output_df

-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Compute cluster centers and predict cluster index for each sample
         For more details on this function, see [sklearn.cluster.MiniBatchKMeans.fit_predict]
         (https://scikit-learn.org/stable/modules/generated/sklearn.cluster.MiniBatchKMeans.html#sklearn.cluster.MiniBatchKMeans.fit_predict)
@@ -647,13 +674,21 @@ class MiniBatchKMeans(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if True:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/cluster/optics.py (+51 -16)

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return True and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class OPTICS(BaseTransformer):
     r"""Estimate clustering structure from vector array
     For more details on this class, see [sklearn.cluster.OPTICS]
@@ -242,7 +254,9 @@ class OPTICS(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -331,11 +345,6 @@ class OPTICS(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -363,7 +372,9 @@ class OPTICS(BaseTransformer):
            label_cols=self.label_cols,
            sample_weight_col=self.sample_weight_col,
            autogenerated=self._autogenerated,
-           subproject=_SUBPROJECT
+           subproject=_SUBPROJECT,
+           use_external_memory_version=self._use_external_memory_version,
+           batch_size=self._batch_size,
        )
        self._sklearn_object = model_trainer.train()
        self._is_fitted = True
@@ -632,6 +643,22 @@ class OPTICS(BaseTransformer):
                # each row containing a list of values.
                expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
        output_df = self._batch_inference(
            dataset=dataset,
            inference_method="transform",
@@ -647,8 +674,8 @@ class OPTICS(BaseTransformer):

        return output_df

-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Perform clustering on `X` and returns cluster labels
         For more details on this function, see [sklearn.cluster.OPTICS.fit_predict]
         (https://scikit-learn.org/stable/modules/generated/sklearn.cluster.OPTICS.html#sklearn.cluster.OPTICS.fit_predict)
@@ -663,13 +690,21 @@ class OPTICS(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if True:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/cluster/spectral_biclustering.py (+51 -16)

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class SpectralBiclustering(BaseTransformer):
     r"""Spectral biclustering (Kluger, 2003)
     For more details on this class, see [sklearn.cluster.SpectralBiclustering]
@@ -184,7 +196,9 @@ class SpectralBiclustering(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -269,11 +283,6 @@ class SpectralBiclustering(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -301,7 +310,9 @@ class SpectralBiclustering(BaseTransformer):
            label_cols=self.label_cols,
            sample_weight_col=self.sample_weight_col,
            autogenerated=self._autogenerated,
-           subproject=_SUBPROJECT
+           subproject=_SUBPROJECT,
+           use_external_memory_version=self._use_external_memory_version,
+           batch_size=self._batch_size,
        )
        self._sklearn_object = model_trainer.train()
        self._is_fitted = True
@@ -570,6 +581,22 @@ class SpectralBiclustering(BaseTransformer):
                # each row containing a list of values.
                expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
        output_df = self._batch_inference(
            dataset=dataset,
            inference_method="transform",
@@ -585,8 +612,8 @@ class SpectralBiclustering(BaseTransformer):

        return output_df

-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -599,13 +626,21 @@ class SpectralBiclustering(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.