snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (215)
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
snowflake/ml/modeling/cluster/agglomerative_clustering.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return True and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class AgglomerativeClustering(BaseTransformer):
     r"""Agglomerative Clustering
     For more details on this class, see [sklearn.cluster.AgglomerativeClustering]
@@ -199,7 +211,9 @@ class AgglomerativeClustering(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -283,11 +297,6 @@ class AgglomerativeClustering(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -315,7 +324,9 @@ class AgglomerativeClustering(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -584,6 +595,22 @@ class AgglomerativeClustering(BaseTransformer):
             # each row containing a list of values.
            expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -599,8 +626,8 @@ class AgglomerativeClustering(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Fit and return the result of each sample's clustering assignment
         For more details on this function, see [sklearn.cluster.AgglomerativeClustering.fit_predict]
         (https://scikit-learn.org/stable/modules/generated/sklearn.cluster.AgglomerativeClustering.html#sklearn.cluster.AgglomerativeClustering.fit_predict)
@@ -615,13 +642,21 @@ class AgglomerativeClustering(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if True:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/cluster/birch.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return True and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class Birch(BaseTransformer):
     r"""Implements the BIRCH clustering algorithm
     For more details on this class, see [sklearn.cluster.Birch]
@@ -161,7 +173,9 @@ class Birch(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -241,11 +255,6 @@ class Birch(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -273,7 +282,9 @@ class Birch(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -546,6 +557,22 @@ class Birch(BaseTransformer):
             # each row containing a list of values.
            expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -561,8 +588,8 @@ class Birch(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Perform clustering on `X` and returns cluster labels
         For more details on this function, see [sklearn.cluster.Birch.fit_predict]
         (https://scikit-learn.org/stable/modules/generated/sklearn.cluster.Birch.html#sklearn.cluster.Birch.fit_predict)
@@ -577,13 +604,21 @@ class Birch(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if True:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/cluster/bisecting_k_means.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return True and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class BisectingKMeans(BaseTransformer):
     r"""Bisecting K-Means clustering
     For more details on this class, see [sklearn.cluster.BisectingKMeans]
@@ -205,7 +217,9 @@ class BisectingKMeans(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -290,11 +304,6 @@ class BisectingKMeans(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -322,7 +331,9 @@ class BisectingKMeans(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -595,6 +606,22 @@ class BisectingKMeans(BaseTransformer):
             # each row containing a list of values.
            expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -610,8 +637,8 @@ class BisectingKMeans(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Compute cluster centers and predict cluster index for each sample
         For more details on this function, see [sklearn.cluster.BisectingKMeans.fit_predict]
         (https://scikit-learn.org/stable/modules/generated/sklearn.cluster.BisectingKMeans.html#sklearn.cluster.BisectingKMeans.fit_predict)
@@ -626,13 +653,21 @@ class BisectingKMeans(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if True:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/cluster/dbscan.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return True and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class DBSCAN(BaseTransformer):
     r"""Perform DBSCAN clustering from vector array or distance matrix
     For more details on this class, see [sklearn.cluster.DBSCAN]
@@ -175,7 +187,9 @@ class DBSCAN(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -258,11 +272,6 @@ class DBSCAN(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -290,7 +299,9 @@ class DBSCAN(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -559,6 +570,22 @@ class DBSCAN(BaseTransformer):
             # each row containing a list of values.
            expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -574,8 +601,8 @@ class DBSCAN(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Compute clusters from a data or distance matrix and predict labels
         For more details on this function, see [sklearn.cluster.DBSCAN.fit_predict]
         (https://scikit-learn.org/stable/modules/generated/sklearn.cluster.DBSCAN.html#sklearn.cluster.DBSCAN.fit_predict)
@@ -590,13 +617,21 @@ class DBSCAN(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if True:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/cluster/feature_agglomeration.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return True and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class FeatureAgglomeration(BaseTransformer):
     r"""Agglomerate features
     For more details on this class, see [sklearn.cluster.FeatureAgglomeration]
@@ -205,7 +217,9 @@ class FeatureAgglomeration(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -290,11 +304,6 @@ class FeatureAgglomeration(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -322,7 +331,9 @@ class FeatureAgglomeration(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -593,6 +604,22 @@ class FeatureAgglomeration(BaseTransformer):
             # each row containing a list of values.
            expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -608,8 +635,8 @@ class FeatureAgglomeration(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Fit and return the result of each sample's clustering assignment
         For more details on this function, see [sklearn.cluster.FeatureAgglomeration.fit_predict]
         (https://scikit-learn.org/stable/modules/generated/sklearn.cluster.FeatureAgglomeration.html#sklearn.cluster.FeatureAgglomeration.fit_predict)
@@ -624,13 +651,21 @@ class FeatureAgglomeration(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if True:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.