snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
snowflake/ml/modeling/decomposition/incremental_pca.py:

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class IncrementalPCA(BaseTransformer):
     r"""Incremental principal components analysis (IPCA)
     For more details on this class, see [sklearn.decomposition.IncrementalPCA]
@@ -150,7 +162,9 @@ class IncrementalPCA(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -229,11 +243,6 @@ class IncrementalPCA(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -261,7 +270,9 @@ class IncrementalPCA(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -532,6 +543,22 @@ class IncrementalPCA(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -547,8 +574,8 @@ class IncrementalPCA(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -561,13 +588,21 @@ class IncrementalPCA(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/decomposition/kernel_pca.py:

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class KernelPCA(BaseTransformer):
     r"""Kernel Principal component analysis (KPCA) [1]_
     For more details on this class, see [sklearn.decomposition.KernelPCA]
@@ -234,7 +246,9 @@ class KernelPCA(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -325,11 +339,6 @@ class KernelPCA(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -357,7 +366,9 @@ class KernelPCA(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -628,6 +639,22 @@ class KernelPCA(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -643,8 +670,8 @@ class KernelPCA(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -657,13 +684,21 @@ class KernelPCA(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py:

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class MiniBatchDictionaryLearning(BaseTransformer):
     r"""Mini-batch dictionary learning
     For more details on this class, see [sklearn.decomposition.MiniBatchDictionaryLearning]
@@ -251,7 +263,9 @@ class MiniBatchDictionaryLearning(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -347,11 +361,6 @@ class MiniBatchDictionaryLearning(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -379,7 +388,9 @@ class MiniBatchDictionaryLearning(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -650,6 +661,22 @@ class MiniBatchDictionaryLearning(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -665,8 +692,8 @@ class MiniBatchDictionaryLearning(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -679,13 +706,21 @@ class MiniBatchDictionaryLearning(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py:

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class MiniBatchSparsePCA(BaseTransformer):
     r"""Mini-batch Sparse Principal Components Analysis
     For more details on this class, see [sklearn.decomposition.MiniBatchSparsePCA]
@@ -203,7 +215,9 @@ class MiniBatchSparsePCA(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -292,11 +306,6 @@ class MiniBatchSparsePCA(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -324,7 +333,9 @@ class MiniBatchSparsePCA(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -595,6 +606,22 @@ class MiniBatchSparsePCA(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -610,8 +637,8 @@ class MiniBatchSparsePCA(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -624,13 +651,21 @@ class MiniBatchSparsePCA(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/decomposition/pca.py:

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class PCA(BaseTransformer):
     r"""Principal component analysis (PCA)
     For more details on this class, see [sklearn.decomposition.PCA]
@@ -210,7 +222,9 @@ class PCA(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
        self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -294,11 +308,6 @@ class PCA(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -326,7 +335,9 @@ class PCA(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -597,6 +608,22 @@ class PCA(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -612,8 +639,8 @@ class PCA(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -626,13 +653,21 @@ class PCA(BaseTransformer):
         Returns:
            Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.