snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
54
54
  _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neighbors".replace("sklearn.", "").split("_")])
55
55
 
56
56
 
57
+ def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
58
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
59
+ return False and callable(getattr(self._sklearn_object, "fit_predict", None))
60
+ return check
61
+
62
+
63
+ def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
+ return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
+ return check
67
+
68
+
57
69
  class NearestNeighbors(BaseTransformer):
58
70
  r"""Unsupervised learner for implementing neighbor searches
59
71
  For more details on this class, see [sklearn.neighbors.NearestNeighbors]
@@ -188,7 +200,9 @@ class NearestNeighbors(BaseTransformer):
188
200
  self.set_label_cols(label_cols)
189
201
  self.set_passthrough_cols(passthrough_cols)
190
202
  self.set_drop_input_cols(drop_input_cols)
191
- self.set_sample_weight_col(sample_weight_col)
203
+ self.set_sample_weight_col(sample_weight_col)
204
+ self._use_external_memory_version = False
205
+ self._batch_size = -1
192
206
  deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
193
207
 
194
208
  self._deps = list(deps)
@@ -271,11 +285,6 @@ class NearestNeighbors(BaseTransformer):
271
285
  if isinstance(dataset, DataFrame):
272
286
  session = dataset._session
273
287
  assert session is not None # keep mypy happy
274
- # Validate that key package version in user workspace are supported in snowflake conda channel
275
- # If customer doesn't have package in conda channel, replace the ones have the closest versions
276
- self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
277
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
278
-
279
288
  # Specify input columns so column pruning will be enforced
280
289
  selected_cols = self._get_active_columns()
281
290
  if len(selected_cols) > 0:
@@ -303,7 +312,9 @@ class NearestNeighbors(BaseTransformer):
303
312
  label_cols=self.label_cols,
304
313
  sample_weight_col=self.sample_weight_col,
305
314
  autogenerated=self._autogenerated,
306
- subproject=_SUBPROJECT
315
+ subproject=_SUBPROJECT,
316
+ use_external_memory_version=self._use_external_memory_version,
317
+ batch_size=self._batch_size,
307
318
  )
308
319
  self._sklearn_object = model_trainer.train()
309
320
  self._is_fitted = True
@@ -572,6 +583,22 @@ class NearestNeighbors(BaseTransformer):
572
583
  # each row containing a list of values.
573
584
  expected_dtype = "ARRAY"
574
585
 
586
+ # If we were unable to assign a type to this transform in the factory, infer the type here.
587
+ if expected_dtype == "":
588
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
589
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
590
+ expected_dtype = "ARRAY"
591
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
592
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
593
+ expected_dtype = "ARRAY"
594
+ else:
595
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
596
+ # We can only infer the output types from the input types if the following two statemetns are true:
597
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
598
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
599
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
600
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
601
+
575
602
  output_df = self._batch_inference(
576
603
  dataset=dataset,
577
604
  inference_method="transform",
@@ -587,8 +614,8 @@ class NearestNeighbors(BaseTransformer):
587
614
 
588
615
  return output_df
589
616
 
590
- @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
591
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
617
+ @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
618
+ def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
592
619
  """ Method not supported for this class.
593
620
 
594
621
 
@@ -601,13 +628,21 @@ class NearestNeighbors(BaseTransformer):
601
628
  Returns:
602
629
  Predicted dataset.
603
630
  """
604
- if False:
605
- self.fit(dataset)
606
- assert self._sklearn_object is not None
607
- labels : npt.NDArray[Any] = self._sklearn_object.labels_
608
- return labels
609
- else:
610
- raise NotImplementedError
631
+ self.fit(dataset)
632
+ assert self._sklearn_object is not None
633
+ return self._sklearn_object.labels_
634
+
635
+
636
+ @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
637
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
638
+ """
639
+ Returns:
640
+ Transformed dataset.
641
+ """
642
+ self.fit(dataset)
643
+ assert self._sklearn_object is not None
644
+ return self._sklearn_object.embedding_
645
+
611
646
 
612
647
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
613
648
  """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
54
54
  _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neighbors".replace("sklearn.", "").split("_")])
55
55
 
56
56
 
57
+ def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
58
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
59
+ return False and callable(getattr(self._sklearn_object, "fit_predict", None))
60
+ return check
61
+
62
+
63
+ def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
+ return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
+ return check
67
+
68
+
57
69
  class NeighborhoodComponentsAnalysis(BaseTransformer):
58
70
  r"""Neighborhood Components Analysis
59
71
  For more details on this class, see [sklearn.neighbors.NeighborhoodComponentsAnalysis]
@@ -209,7 +221,9 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
209
221
  self.set_label_cols(label_cols)
210
222
  self.set_passthrough_cols(passthrough_cols)
211
223
  self.set_drop_input_cols(drop_input_cols)
212
- self.set_sample_weight_col(sample_weight_col)
224
+ self.set_sample_weight_col(sample_weight_col)
225
+ self._use_external_memory_version = False
226
+ self._batch_size = -1
213
227
  deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
214
228
 
215
229
  self._deps = list(deps)
@@ -292,11 +306,6 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
292
306
  if isinstance(dataset, DataFrame):
293
307
  session = dataset._session
294
308
  assert session is not None # keep mypy happy
295
- # Validate that key package version in user workspace are supported in snowflake conda channel
296
- # If customer doesn't have package in conda channel, replace the ones have the closest versions
297
- self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
298
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
299
-
300
309
  # Specify input columns so column pruning will be enforced
301
310
  selected_cols = self._get_active_columns()
302
311
  if len(selected_cols) > 0:
@@ -324,7 +333,9 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
324
333
  label_cols=self.label_cols,
325
334
  sample_weight_col=self.sample_weight_col,
326
335
  autogenerated=self._autogenerated,
327
- subproject=_SUBPROJECT
336
+ subproject=_SUBPROJECT,
337
+ use_external_memory_version=self._use_external_memory_version,
338
+ batch_size=self._batch_size,
328
339
  )
329
340
  self._sklearn_object = model_trainer.train()
330
341
  self._is_fitted = True
@@ -595,6 +606,22 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
595
606
  # each row containing a list of values.
596
607
  expected_dtype = "ARRAY"
597
608
 
609
+ # If we were unable to assign a type to this transform in the factory, infer the type here.
610
+ if expected_dtype == "":
611
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
612
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
613
+ expected_dtype = "ARRAY"
614
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
615
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
616
+ expected_dtype = "ARRAY"
617
+ else:
618
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
619
+ # We can only infer the output types from the input types if the following two statemetns are true:
620
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
621
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
622
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
623
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
624
+
598
625
  output_df = self._batch_inference(
599
626
  dataset=dataset,
600
627
  inference_method="transform",
@@ -610,8 +637,8 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
610
637
 
611
638
  return output_df
612
639
 
613
- @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
614
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
640
+ @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
641
+ def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
615
642
  """ Method not supported for this class.
616
643
 
617
644
 
@@ -624,13 +651,21 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
624
651
  Returns:
625
652
  Predicted dataset.
626
653
  """
627
- if False:
628
- self.fit(dataset)
629
- assert self._sklearn_object is not None
630
- labels : npt.NDArray[Any] = self._sklearn_object.labels_
631
- return labels
632
- else:
633
- raise NotImplementedError
654
+ self.fit(dataset)
655
+ assert self._sklearn_object is not None
656
+ return self._sklearn_object.labels_
657
+
658
+
659
+ @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
660
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
661
+ """
662
+ Returns:
663
+ Transformed dataset.
664
+ """
665
+ self.fit(dataset)
666
+ assert self._sklearn_object is not None
667
+ return self._sklearn_object.embedding_
668
+
634
669
 
635
670
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
636
671
  """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
54
54
  _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neighbors".replace("sklearn.", "").split("_")])
55
55
 
56
56
 
57
+ def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
58
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
59
+ return False and callable(getattr(self._sklearn_object, "fit_predict", None))
60
+ return check
61
+
62
+
63
+ def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
+ return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
+ return check
67
+
68
+
57
69
  class RadiusNeighborsClassifier(BaseTransformer):
58
70
  r"""Classifier implementing a vote among neighbors within a given radius
59
71
  For more details on this class, see [sklearn.neighbors.RadiusNeighborsClassifier]
@@ -209,7 +221,9 @@ class RadiusNeighborsClassifier(BaseTransformer):
209
221
  self.set_label_cols(label_cols)
210
222
  self.set_passthrough_cols(passthrough_cols)
211
223
  self.set_drop_input_cols(drop_input_cols)
212
- self.set_sample_weight_col(sample_weight_col)
224
+ self.set_sample_weight_col(sample_weight_col)
225
+ self._use_external_memory_version = False
226
+ self._batch_size = -1
213
227
  deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
214
228
 
215
229
  self._deps = list(deps)
@@ -293,11 +307,6 @@ class RadiusNeighborsClassifier(BaseTransformer):
293
307
  if isinstance(dataset, DataFrame):
294
308
  session = dataset._session
295
309
  assert session is not None # keep mypy happy
296
- # Validate that key package version in user workspace are supported in snowflake conda channel
297
- # If customer doesn't have package in conda channel, replace the ones have the closest versions
298
- self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
299
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
300
-
301
310
  # Specify input columns so column pruning will be enforced
302
311
  selected_cols = self._get_active_columns()
303
312
  if len(selected_cols) > 0:
@@ -325,7 +334,9 @@ class RadiusNeighborsClassifier(BaseTransformer):
325
334
  label_cols=self.label_cols,
326
335
  sample_weight_col=self.sample_weight_col,
327
336
  autogenerated=self._autogenerated,
328
- subproject=_SUBPROJECT
337
+ subproject=_SUBPROJECT,
338
+ use_external_memory_version=self._use_external_memory_version,
339
+ batch_size=self._batch_size,
329
340
  )
330
341
  self._sklearn_object = model_trainer.train()
331
342
  self._is_fitted = True
@@ -596,6 +607,22 @@ class RadiusNeighborsClassifier(BaseTransformer):
596
607
  # each row containing a list of values.
597
608
  expected_dtype = "ARRAY"
598
609
 
610
+ # If we were unable to assign a type to this transform in the factory, infer the type here.
611
+ if expected_dtype == "":
612
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
613
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
614
+ expected_dtype = "ARRAY"
615
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
616
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
617
+ expected_dtype = "ARRAY"
618
+ else:
619
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
620
+ # We can only infer the output types from the input types if the following two statemetns are true:
621
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
622
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
623
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
624
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
625
+
599
626
  output_df = self._batch_inference(
600
627
  dataset=dataset,
601
628
  inference_method="transform",
@@ -611,8 +638,8 @@ class RadiusNeighborsClassifier(BaseTransformer):
611
638
 
612
639
  return output_df
613
640
 
614
- @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
615
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
641
+ @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
642
+ def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
616
643
  """ Method not supported for this class.
617
644
 
618
645
 
@@ -625,13 +652,21 @@ class RadiusNeighborsClassifier(BaseTransformer):
625
652
  Returns:
626
653
  Predicted dataset.
627
654
  """
628
- if False:
629
- self.fit(dataset)
630
- assert self._sklearn_object is not None
631
- labels : npt.NDArray[Any] = self._sklearn_object.labels_
632
- return labels
633
- else:
634
- raise NotImplementedError
655
+ self.fit(dataset)
656
+ assert self._sklearn_object is not None
657
+ return self._sklearn_object.labels_
658
+
659
+
660
+ @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
661
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
662
+ """
663
+ Returns:
664
+ Transformed dataset.
665
+ """
666
+ self.fit(dataset)
667
+ assert self._sklearn_object is not None
668
+ return self._sklearn_object.embedding_
669
+
635
670
 
636
671
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
637
672
  """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
54
54
  _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neighbors".replace("sklearn.", "").split("_")])
55
55
 
56
56
 
57
+ def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
58
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
59
+ return False and callable(getattr(self._sklearn_object, "fit_predict", None))
60
+ return check
61
+
62
+
63
+ def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
+ return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
+ return check
67
+
68
+
57
69
  class RadiusNeighborsRegressor(BaseTransformer):
58
70
  r"""Regression based on neighbors within a fixed radius
59
71
  For more details on this class, see [sklearn.neighbors.RadiusNeighborsRegressor]
@@ -200,7 +212,9 @@ class RadiusNeighborsRegressor(BaseTransformer):
200
212
  self.set_label_cols(label_cols)
201
213
  self.set_passthrough_cols(passthrough_cols)
202
214
  self.set_drop_input_cols(drop_input_cols)
203
- self.set_sample_weight_col(sample_weight_col)
215
+ self.set_sample_weight_col(sample_weight_col)
216
+ self._use_external_memory_version = False
217
+ self._batch_size = -1
204
218
  deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
205
219
 
206
220
  self._deps = list(deps)
@@ -283,11 +297,6 @@ class RadiusNeighborsRegressor(BaseTransformer):
283
297
  if isinstance(dataset, DataFrame):
284
298
  session = dataset._session
285
299
  assert session is not None # keep mypy happy
286
- # Validate that key package version in user workspace are supported in snowflake conda channel
287
- # If customer doesn't have package in conda channel, replace the ones have the closest versions
288
- self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
289
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
290
-
291
300
  # Specify input columns so column pruning will be enforced
292
301
  selected_cols = self._get_active_columns()
293
302
  if len(selected_cols) > 0:
@@ -315,7 +324,9 @@ class RadiusNeighborsRegressor(BaseTransformer):
315
324
  label_cols=self.label_cols,
316
325
  sample_weight_col=self.sample_weight_col,
317
326
  autogenerated=self._autogenerated,
318
- subproject=_SUBPROJECT
327
+ subproject=_SUBPROJECT,
328
+ use_external_memory_version=self._use_external_memory_version,
329
+ batch_size=self._batch_size,
319
330
  )
320
331
  self._sklearn_object = model_trainer.train()
321
332
  self._is_fitted = True
@@ -586,6 +597,22 @@ class RadiusNeighborsRegressor(BaseTransformer):
586
597
  # each row containing a list of values.
587
598
  expected_dtype = "ARRAY"
588
599
 
600
+ # If we were unable to assign a type to this transform in the factory, infer the type here.
601
+ if expected_dtype == "":
602
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
603
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
604
+ expected_dtype = "ARRAY"
605
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
606
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
607
+ expected_dtype = "ARRAY"
608
+ else:
609
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
610
+ # We can only infer the output types from the input types if the following two statemetns are true:
611
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
612
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
613
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
614
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
615
+
589
616
  output_df = self._batch_inference(
590
617
  dataset=dataset,
591
618
  inference_method="transform",
@@ -601,8 +628,8 @@ class RadiusNeighborsRegressor(BaseTransformer):
601
628
 
602
629
  return output_df
603
630
 
604
- @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
605
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
631
+ @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
632
+ def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
606
633
  """ Method not supported for this class.
607
634
 
608
635
 
@@ -615,13 +642,21 @@ class RadiusNeighborsRegressor(BaseTransformer):
615
642
  Returns:
616
643
  Predicted dataset.
617
644
  """
618
- if False:
619
- self.fit(dataset)
620
- assert self._sklearn_object is not None
621
- labels : npt.NDArray[Any] = self._sklearn_object.labels_
622
- return labels
623
- else:
624
- raise NotImplementedError
645
+ self.fit(dataset)
646
+ assert self._sklearn_object is not None
647
+ return self._sklearn_object.labels_
648
+
649
+
650
+ @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
651
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
652
+ """
653
+ Returns:
654
+ Transformed dataset.
655
+ """
656
+ self.fit(dataset)
657
+ assert self._sklearn_object is not None
658
+ return self._sklearn_object.embedding_
659
+
625
660
 
626
661
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
627
662
  """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
54
54
  _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neural_network".replace("sklearn.", "").split("_")])
55
55
 
56
56
 
57
+ def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
58
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
59
+ return False and callable(getattr(self._sklearn_object, "fit_predict", None))
60
+ return check
61
+
62
+
63
+ def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
+ return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
+ return check
67
+
68
+
57
69
  class BernoulliRBM(BaseTransformer):
58
70
  r"""Bernoulli Restricted Boltzmann Machine (RBM)
59
71
  For more details on this class, see [sklearn.neural_network.BernoulliRBM]
@@ -159,7 +171,9 @@ class BernoulliRBM(BaseTransformer):
159
171
  self.set_label_cols(label_cols)
160
172
  self.set_passthrough_cols(passthrough_cols)
161
173
  self.set_drop_input_cols(drop_input_cols)
162
- self.set_sample_weight_col(sample_weight_col)
174
+ self.set_sample_weight_col(sample_weight_col)
175
+ self._use_external_memory_version = False
176
+ self._batch_size = -1
163
177
  deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
164
178
 
165
179
  self._deps = list(deps)
@@ -240,11 +254,6 @@ class BernoulliRBM(BaseTransformer):
240
254
  if isinstance(dataset, DataFrame):
241
255
  session = dataset._session
242
256
  assert session is not None # keep mypy happy
243
- # Validate that key package version in user workspace are supported in snowflake conda channel
244
- # If customer doesn't have package in conda channel, replace the ones have the closest versions
245
- self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
246
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
247
-
248
257
  # Specify input columns so column pruning will be enforced
249
258
  selected_cols = self._get_active_columns()
250
259
  if len(selected_cols) > 0:
@@ -272,7 +281,9 @@ class BernoulliRBM(BaseTransformer):
272
281
  label_cols=self.label_cols,
273
282
  sample_weight_col=self.sample_weight_col,
274
283
  autogenerated=self._autogenerated,
275
- subproject=_SUBPROJECT
284
+ subproject=_SUBPROJECT,
285
+ use_external_memory_version=self._use_external_memory_version,
286
+ batch_size=self._batch_size,
276
287
  )
277
288
  self._sklearn_object = model_trainer.train()
278
289
  self._is_fitted = True
@@ -543,6 +554,22 @@ class BernoulliRBM(BaseTransformer):
543
554
  # each row containing a list of values.
544
555
  expected_dtype = "ARRAY"
545
556
 
557
+ # If we were unable to assign a type to this transform in the factory, infer the type here.
558
+ if expected_dtype == "":
559
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
560
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
561
+ expected_dtype = "ARRAY"
562
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
563
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
564
+ expected_dtype = "ARRAY"
565
+ else:
566
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
567
+ # We can only infer the output types from the input types if the following two statemetns are true:
568
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
569
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
570
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
571
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
572
+
546
573
  output_df = self._batch_inference(
547
574
  dataset=dataset,
548
575
  inference_method="transform",
@@ -558,8 +585,8 @@ class BernoulliRBM(BaseTransformer):
558
585
 
559
586
  return output_df
560
587
 
561
- @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
562
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
588
+ @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
589
+ def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
563
590
  """ Method not supported for this class.
564
591
 
565
592
 
@@ -572,13 +599,21 @@ class BernoulliRBM(BaseTransformer):
572
599
  Returns:
573
600
  Predicted dataset.
574
601
  """
575
- if False:
576
- self.fit(dataset)
577
- assert self._sklearn_object is not None
578
- labels : npt.NDArray[Any] = self._sklearn_object.labels_
579
- return labels
580
- else:
581
- raise NotImplementedError
602
+ self.fit(dataset)
603
+ assert self._sklearn_object is not None
604
+ return self._sklearn_object.labels_
605
+
606
+
607
+ @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
608
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
609
+ """
610
+ Returns:
611
+ Transformed dataset.
612
+ """
613
+ self.fit(dataset)
614
+ assert self._sklearn_object is not None
615
+ return self._sklearn_object.embedding_
616
+
582
617
 
583
618
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
584
619
  """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.