snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215):
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
54
54
  _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neighbors".replace("sklearn.", "").split("_")])
55
55
 
56
56
 
57
+ def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
58
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
59
+ return False and callable(getattr(self._sklearn_object, "fit_predict", None))
60
+ return check
61
+
62
+
63
+ def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
+ return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
+ return check
67
+
68
+
57
69
  class KNeighborsClassifier(BaseTransformer):
58
70
  r"""Classifier implementing the k-nearest neighbors vote
59
71
  For more details on this class, see [sklearn.neighbors.KNeighborsClassifier]
@@ -198,7 +210,9 @@ class KNeighborsClassifier(BaseTransformer):
198
210
  self.set_label_cols(label_cols)
199
211
  self.set_passthrough_cols(passthrough_cols)
200
212
  self.set_drop_input_cols(drop_input_cols)
201
- self.set_sample_weight_col(sample_weight_col)
213
+ self.set_sample_weight_col(sample_weight_col)
214
+ self._use_external_memory_version = False
215
+ self._batch_size = -1
202
216
  deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
203
217
 
204
218
  self._deps = list(deps)
@@ -281,11 +295,6 @@ class KNeighborsClassifier(BaseTransformer):
281
295
  if isinstance(dataset, DataFrame):
282
296
  session = dataset._session
283
297
  assert session is not None # keep mypy happy
284
- # Validate that key package version in user workspace are supported in snowflake conda channel
285
- # If customer doesn't have package in conda channel, replace the ones have the closest versions
286
- self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
287
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
288
-
289
298
  # Specify input columns so column pruning will be enforced
290
299
  selected_cols = self._get_active_columns()
291
300
  if len(selected_cols) > 0:
@@ -313,7 +322,9 @@ class KNeighborsClassifier(BaseTransformer):
313
322
  label_cols=self.label_cols,
314
323
  sample_weight_col=self.sample_weight_col,
315
324
  autogenerated=self._autogenerated,
316
- subproject=_SUBPROJECT
325
+ subproject=_SUBPROJECT,
326
+ use_external_memory_version=self._use_external_memory_version,
327
+ batch_size=self._batch_size,
317
328
  )
318
329
  self._sklearn_object = model_trainer.train()
319
330
  self._is_fitted = True
@@ -584,6 +595,22 @@ class KNeighborsClassifier(BaseTransformer):
584
595
  # each row containing a list of values.
585
596
  expected_dtype = "ARRAY"
586
597
 
598
+ # If we were unable to assign a type to this transform in the factory, infer the type here.
599
+ if expected_dtype == "":
600
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
601
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
602
+ expected_dtype = "ARRAY"
603
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
604
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
605
+ expected_dtype = "ARRAY"
606
+ else:
607
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
608
+ # We can only infer the output types from the input types if the following two statemetns are true:
609
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
610
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
611
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
612
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
613
+
587
614
  output_df = self._batch_inference(
588
615
  dataset=dataset,
589
616
  inference_method="transform",
@@ -599,8 +626,8 @@ class KNeighborsClassifier(BaseTransformer):
599
626
 
600
627
  return output_df
601
628
 
602
- @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
603
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
629
+ @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
630
+ def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
604
631
  """ Method not supported for this class.
605
632
 
606
633
 
@@ -613,13 +640,21 @@ class KNeighborsClassifier(BaseTransformer):
613
640
  Returns:
614
641
  Predicted dataset.
615
642
  """
616
- if False:
617
- self.fit(dataset)
618
- assert self._sklearn_object is not None
619
- labels : npt.NDArray[Any] = self._sklearn_object.labels_
620
- return labels
621
- else:
622
- raise NotImplementedError
643
+ self.fit(dataset)
644
+ assert self._sklearn_object is not None
645
+ return self._sklearn_object.labels_
646
+
647
+
648
+ @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
649
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
650
+ """
651
+ Returns:
652
+ Transformed dataset.
653
+ """
654
+ self.fit(dataset)
655
+ assert self._sklearn_object is not None
656
+ return self._sklearn_object.embedding_
657
+
623
658
 
624
659
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
625
660
  """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
54
54
  _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neighbors".replace("sklearn.", "").split("_")])
55
55
 
56
56
 
57
+ def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
58
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
59
+ return False and callable(getattr(self._sklearn_object, "fit_predict", None))
60
+ return check
61
+
62
+
63
+ def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
+ return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
+ return check
67
+
68
+
57
69
  class KNeighborsRegressor(BaseTransformer):
58
70
  r"""Regression based on k-nearest neighbors
59
71
  For more details on this class, see [sklearn.neighbors.KNeighborsRegressor]
@@ -200,7 +212,9 @@ class KNeighborsRegressor(BaseTransformer):
200
212
  self.set_label_cols(label_cols)
201
213
  self.set_passthrough_cols(passthrough_cols)
202
214
  self.set_drop_input_cols(drop_input_cols)
203
- self.set_sample_weight_col(sample_weight_col)
215
+ self.set_sample_weight_col(sample_weight_col)
216
+ self._use_external_memory_version = False
217
+ self._batch_size = -1
204
218
  deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
205
219
 
206
220
  self._deps = list(deps)
@@ -283,11 +297,6 @@ class KNeighborsRegressor(BaseTransformer):
283
297
  if isinstance(dataset, DataFrame):
284
298
  session = dataset._session
285
299
  assert session is not None # keep mypy happy
286
- # Validate that key package version in user workspace are supported in snowflake conda channel
287
- # If customer doesn't have package in conda channel, replace the ones have the closest versions
288
- self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
289
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
290
-
291
300
  # Specify input columns so column pruning will be enforced
292
301
  selected_cols = self._get_active_columns()
293
302
  if len(selected_cols) > 0:
@@ -315,7 +324,9 @@ class KNeighborsRegressor(BaseTransformer):
315
324
  label_cols=self.label_cols,
316
325
  sample_weight_col=self.sample_weight_col,
317
326
  autogenerated=self._autogenerated,
318
- subproject=_SUBPROJECT
327
+ subproject=_SUBPROJECT,
328
+ use_external_memory_version=self._use_external_memory_version,
329
+ batch_size=self._batch_size,
319
330
  )
320
331
  self._sklearn_object = model_trainer.train()
321
332
  self._is_fitted = True
@@ -586,6 +597,22 @@ class KNeighborsRegressor(BaseTransformer):
586
597
  # each row containing a list of values.
587
598
  expected_dtype = "ARRAY"
588
599
 
600
+ # If we were unable to assign a type to this transform in the factory, infer the type here.
601
+ if expected_dtype == "":
602
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
603
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
604
+ expected_dtype = "ARRAY"
605
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
606
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
607
+ expected_dtype = "ARRAY"
608
+ else:
609
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
610
+ # We can only infer the output types from the input types if the following two statemetns are true:
611
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
612
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
613
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
614
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
615
+
589
616
  output_df = self._batch_inference(
590
617
  dataset=dataset,
591
618
  inference_method="transform",
@@ -601,8 +628,8 @@ class KNeighborsRegressor(BaseTransformer):
601
628
 
602
629
  return output_df
603
630
 
604
- @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
605
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
631
+ @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
632
+ def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
606
633
  """ Method not supported for this class.
607
634
 
608
635
 
@@ -615,13 +642,21 @@ class KNeighborsRegressor(BaseTransformer):
615
642
  Returns:
616
643
  Predicted dataset.
617
644
  """
618
- if False:
619
- self.fit(dataset)
620
- assert self._sklearn_object is not None
621
- labels : npt.NDArray[Any] = self._sklearn_object.labels_
622
- return labels
623
- else:
624
- raise NotImplementedError
645
+ self.fit(dataset)
646
+ assert self._sklearn_object is not None
647
+ return self._sklearn_object.labels_
648
+
649
+
650
+ @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
651
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
652
+ """
653
+ Returns:
654
+ Transformed dataset.
655
+ """
656
+ self.fit(dataset)
657
+ assert self._sklearn_object is not None
658
+ return self._sklearn_object.embedding_
659
+
625
660
 
626
661
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
627
662
  """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
54
54
  _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neighbors".replace("sklearn.", "").split("_")])
55
55
 
56
56
 
57
+ def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
58
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
59
+ return False and callable(getattr(self._sklearn_object, "fit_predict", None))
60
+ return check
61
+
62
+
63
+ def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
+ return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
+ return check
67
+
68
+
57
69
  class KernelDensity(BaseTransformer):
58
70
  r"""Kernel Density Estimation
59
71
  For more details on this class, see [sklearn.neighbors.KernelDensity]
@@ -176,7 +188,9 @@ class KernelDensity(BaseTransformer):
176
188
  self.set_label_cols(label_cols)
177
189
  self.set_passthrough_cols(passthrough_cols)
178
190
  self.set_drop_input_cols(drop_input_cols)
179
- self.set_sample_weight_col(sample_weight_col)
191
+ self.set_sample_weight_col(sample_weight_col)
192
+ self._use_external_memory_version = False
193
+ self._batch_size = -1
180
194
  deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
181
195
 
182
196
  self._deps = list(deps)
@@ -260,11 +274,6 @@ class KernelDensity(BaseTransformer):
260
274
  if isinstance(dataset, DataFrame):
261
275
  session = dataset._session
262
276
  assert session is not None # keep mypy happy
263
- # Validate that key package version in user workspace are supported in snowflake conda channel
264
- # If customer doesn't have package in conda channel, replace the ones have the closest versions
265
- self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
266
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
267
-
268
277
  # Specify input columns so column pruning will be enforced
269
278
  selected_cols = self._get_active_columns()
270
279
  if len(selected_cols) > 0:
@@ -292,7 +301,9 @@ class KernelDensity(BaseTransformer):
292
301
  label_cols=self.label_cols,
293
302
  sample_weight_col=self.sample_weight_col,
294
303
  autogenerated=self._autogenerated,
295
- subproject=_SUBPROJECT
304
+ subproject=_SUBPROJECT,
305
+ use_external_memory_version=self._use_external_memory_version,
306
+ batch_size=self._batch_size,
296
307
  )
297
308
  self._sklearn_object = model_trainer.train()
298
309
  self._is_fitted = True
@@ -561,6 +572,22 @@ class KernelDensity(BaseTransformer):
561
572
  # each row containing a list of values.
562
573
  expected_dtype = "ARRAY"
563
574
 
575
+ # If we were unable to assign a type to this transform in the factory, infer the type here.
576
+ if expected_dtype == "":
577
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
578
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
579
+ expected_dtype = "ARRAY"
580
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
581
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
582
+ expected_dtype = "ARRAY"
583
+ else:
584
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
585
+ # We can only infer the output types from the input types if the following two statemetns are true:
586
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
587
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
588
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
589
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
590
+
564
591
  output_df = self._batch_inference(
565
592
  dataset=dataset,
566
593
  inference_method="transform",
@@ -576,8 +603,8 @@ class KernelDensity(BaseTransformer):
576
603
 
577
604
  return output_df
578
605
 
579
- @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
580
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
606
+ @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
607
+ def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
581
608
  """ Method not supported for this class.
582
609
 
583
610
 
@@ -590,13 +617,21 @@ class KernelDensity(BaseTransformer):
590
617
  Returns:
591
618
  Predicted dataset.
592
619
  """
593
- if False:
594
- self.fit(dataset)
595
- assert self._sklearn_object is not None
596
- labels : npt.NDArray[Any] = self._sklearn_object.labels_
597
- return labels
598
- else:
599
- raise NotImplementedError
620
+ self.fit(dataset)
621
+ assert self._sklearn_object is not None
622
+ return self._sklearn_object.labels_
623
+
624
+
625
+ @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
626
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
627
+ """
628
+ Returns:
629
+ Transformed dataset.
630
+ """
631
+ self.fit(dataset)
632
+ assert self._sklearn_object is not None
633
+ return self._sklearn_object.embedding_
634
+
600
635
 
601
636
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
602
637
  """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
54
54
  _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neighbors".replace("sklearn.", "").split("_")])
55
55
 
56
56
 
57
+ def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
58
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
59
+ return False and callable(getattr(self._sklearn_object, "fit_predict", None))
60
+ return check
61
+
62
+
63
+ def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
+ return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
+ return check
67
+
68
+
57
69
  class LocalOutlierFactor(BaseTransformer):
58
70
  r"""Unsupervised Outlier Detection using the Local Outlier Factor (LOF)
59
71
  For more details on this class, see [sklearn.neighbors.LocalOutlierFactor]
@@ -204,7 +216,9 @@ class LocalOutlierFactor(BaseTransformer):
204
216
  self.set_label_cols(label_cols)
205
217
  self.set_passthrough_cols(passthrough_cols)
206
218
  self.set_drop_input_cols(drop_input_cols)
207
- self.set_sample_weight_col(sample_weight_col)
219
+ self.set_sample_weight_col(sample_weight_col)
220
+ self._use_external_memory_version = False
221
+ self._batch_size = -1
208
222
  deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
209
223
 
210
224
  self._deps = list(deps)
@@ -288,11 +302,6 @@ class LocalOutlierFactor(BaseTransformer):
288
302
  if isinstance(dataset, DataFrame):
289
303
  session = dataset._session
290
304
  assert session is not None # keep mypy happy
291
- # Validate that key package version in user workspace are supported in snowflake conda channel
292
- # If customer doesn't have package in conda channel, replace the ones have the closest versions
293
- self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
294
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
295
-
296
305
  # Specify input columns so column pruning will be enforced
297
306
  selected_cols = self._get_active_columns()
298
307
  if len(selected_cols) > 0:
@@ -320,7 +329,9 @@ class LocalOutlierFactor(BaseTransformer):
320
329
  label_cols=self.label_cols,
321
330
  sample_weight_col=self.sample_weight_col,
322
331
  autogenerated=self._autogenerated,
323
- subproject=_SUBPROJECT
332
+ subproject=_SUBPROJECT,
333
+ use_external_memory_version=self._use_external_memory_version,
334
+ batch_size=self._batch_size,
324
335
  )
325
336
  self._sklearn_object = model_trainer.train()
326
337
  self._is_fitted = True
@@ -591,6 +602,22 @@ class LocalOutlierFactor(BaseTransformer):
591
602
  # each row containing a list of values.
592
603
  expected_dtype = "ARRAY"
593
604
 
605
+ # If we were unable to assign a type to this transform in the factory, infer the type here.
606
+ if expected_dtype == "":
607
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
608
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
609
+ expected_dtype = "ARRAY"
610
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
611
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
612
+ expected_dtype = "ARRAY"
613
+ else:
614
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
615
+ # We can only infer the output types from the input types if the following two statemetns are true:
616
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
617
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
618
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
619
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
620
+
594
621
  output_df = self._batch_inference(
595
622
  dataset=dataset,
596
623
  inference_method="transform",
@@ -606,8 +633,8 @@ class LocalOutlierFactor(BaseTransformer):
606
633
 
607
634
  return output_df
608
635
 
609
- @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
610
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
636
+ @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
637
+ def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
611
638
  """ Fit the model to the training set X and return the labels
612
639
  For more details on this function, see [sklearn.neighbors.LocalOutlierFactor.fit_predict]
613
640
  (https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.LocalOutlierFactor.html#sklearn.neighbors.LocalOutlierFactor.fit_predict)
@@ -622,13 +649,21 @@ class LocalOutlierFactor(BaseTransformer):
622
649
  Returns:
623
650
  Predicted dataset.
624
651
  """
625
- if False:
626
- self.fit(dataset)
627
- assert self._sklearn_object is not None
628
- labels : npt.NDArray[Any] = self._sklearn_object.labels_
629
- return labels
630
- else:
631
- raise NotImplementedError
652
+ self.fit(dataset)
653
+ assert self._sklearn_object is not None
654
+ return self._sklearn_object.labels_
655
+
656
+
657
+ @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
658
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
659
+ """
660
+ Returns:
661
+ Transformed dataset.
662
+ """
663
+ self.fit(dataset)
664
+ assert self._sklearn_object is not None
665
+ return self._sklearn_object.embedding_
666
+
632
667
 
633
668
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
634
669
  """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
54
54
  _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neighbors".replace("sklearn.", "").split("_")])
55
55
 
56
56
 
57
+ def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
58
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
59
+ return False and callable(getattr(self._sklearn_object, "fit_predict", None))
60
+ return check
61
+
62
+
63
+ def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
+ return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
+ return check
67
+
68
+
57
69
  class NearestCentroid(BaseTransformer):
58
70
  r"""Nearest centroid classifier
59
71
  For more details on this class, see [sklearn.neighbors.NearestCentroid]
@@ -144,7 +156,9 @@ class NearestCentroid(BaseTransformer):
144
156
  self.set_label_cols(label_cols)
145
157
  self.set_passthrough_cols(passthrough_cols)
146
158
  self.set_drop_input_cols(drop_input_cols)
147
- self.set_sample_weight_col(sample_weight_col)
159
+ self.set_sample_weight_col(sample_weight_col)
160
+ self._use_external_memory_version = False
161
+ self._batch_size = -1
148
162
  deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
149
163
 
150
164
  self._deps = list(deps)
@@ -221,11 +235,6 @@ class NearestCentroid(BaseTransformer):
221
235
  if isinstance(dataset, DataFrame):
222
236
  session = dataset._session
223
237
  assert session is not None # keep mypy happy
224
- # Validate that key package version in user workspace are supported in snowflake conda channel
225
- # If customer doesn't have package in conda channel, replace the ones have the closest versions
226
- self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
227
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
228
-
229
238
  # Specify input columns so column pruning will be enforced
230
239
  selected_cols = self._get_active_columns()
231
240
  if len(selected_cols) > 0:
@@ -253,7 +262,9 @@ class NearestCentroid(BaseTransformer):
253
262
  label_cols=self.label_cols,
254
263
  sample_weight_col=self.sample_weight_col,
255
264
  autogenerated=self._autogenerated,
256
- subproject=_SUBPROJECT
265
+ subproject=_SUBPROJECT,
266
+ use_external_memory_version=self._use_external_memory_version,
267
+ batch_size=self._batch_size,
257
268
  )
258
269
  self._sklearn_object = model_trainer.train()
259
270
  self._is_fitted = True
@@ -524,6 +535,22 @@ class NearestCentroid(BaseTransformer):
524
535
  # each row containing a list of values.
525
536
  expected_dtype = "ARRAY"
526
537
 
538
+ # If we were unable to assign a type to this transform in the factory, infer the type here.
539
+ if expected_dtype == "":
540
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
541
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
542
+ expected_dtype = "ARRAY"
543
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
544
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
545
+ expected_dtype = "ARRAY"
546
+ else:
547
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
548
+ # We can only infer the output types from the input types if the following two statemetns are true:
549
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
550
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
551
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
552
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
553
+
527
554
  output_df = self._batch_inference(
528
555
  dataset=dataset,
529
556
  inference_method="transform",
@@ -539,8 +566,8 @@ class NearestCentroid(BaseTransformer):
539
566
 
540
567
  return output_df
541
568
 
542
- @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
543
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
569
+ @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
570
+ def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
544
571
  """ Method not supported for this class.
545
572
 
546
573
 
@@ -553,13 +580,21 @@ class NearestCentroid(BaseTransformer):
553
580
  Returns:
554
581
  Predicted dataset.
555
582
  """
556
- if False:
557
- self.fit(dataset)
558
- assert self._sklearn_object is not None
559
- labels : npt.NDArray[Any] = self._sklearn_object.labels_
560
- return labels
561
- else:
562
- raise NotImplementedError
583
+ self.fit(dataset)
584
+ assert self._sklearn_object is not None
585
+ return self._sklearn_object.labels_
586
+
587
+
588
+ @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
589
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
590
+ """
591
+ Returns:
592
+ Transformed dataset.
593
+ """
594
+ self.fit(dataset)
595
+ assert self._sklearn_object is not None
596
+ return self._sklearn_object.embedding_
597
+
563
598
 
564
599
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
565
600
  """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.