snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
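
Note: item 206 bumps snowflake/ml/version.py and items 207-211 swap the dist-info metadata for the new wheel. A quick way to confirm which of the two wheels is actually installed in an environment, using only the standard library (a generic sketch, not part of this diff):

from importlib.metadata import version

# importlib.metadata reads the installed wheel's METADATA, so no package internals are assumed.
print(version("snowflake-ml-python"))  # "1.1.2" before the upgrade, "1.2.1" after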
snowflake/ml/modeling/covariance/oas.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.covariance".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class OAS(BaseTransformer):
     r"""Oracle Approximating Shrinkage Estimator as proposed in [1]_
     For more details on this class, see [sklearn.covariance.OAS]
@@ -133,7 +145,9 @@ class OAS(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -210,11 +224,6 @@ class OAS(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -242,7 +251,9 @@ class OAS(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -511,6 +522,22 @@ class OAS(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -526,8 +553,8 @@ class OAS(BaseTransformer):

         return output_df

-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -540,13 +567,21 @@ class OAS(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/covariance/shrunk_covariance.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.covariance".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class ShrunkCovariance(BaseTransformer):
     r"""Covariance estimator with shrinkage
     For more details on this class, see [sklearn.covariance.ShrunkCovariance]
@@ -138,7 +150,9 @@ class ShrunkCovariance(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -216,11 +230,6 @@ class ShrunkCovariance(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -248,7 +257,9 @@ class ShrunkCovariance(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -517,6 +528,22 @@ class ShrunkCovariance(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -532,8 +559,8 @@ class ShrunkCovariance(BaseTransformer):

         return output_df

-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -546,13 +573,21 @@ class ShrunkCovariance(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/decomposition/dictionary_learning.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class DictionaryLearning(BaseTransformer):
     r"""Dictionary learning
     For more details on this class, see [sklearn.decomposition.DictionaryLearning]
@@ -229,7 +241,9 @@ class DictionaryLearning(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -322,11 +336,6 @@ class DictionaryLearning(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -354,7 +363,9 @@ class DictionaryLearning(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -625,6 +636,22 @@ class DictionaryLearning(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -640,8 +667,8 @@ class DictionaryLearning(BaseTransformer):

         return output_df

-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -654,13 +681,21 @@ class DictionaryLearning(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/decomposition/factor_analysis.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class FactorAnalysis(BaseTransformer):
     r"""Factor Analysis (FA)
     For more details on this class, see [sklearn.decomposition.FactorAnalysis]
@@ -175,7 +187,9 @@ class FactorAnalysis(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -259,11 +273,6 @@ class FactorAnalysis(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -291,7 +300,9 @@ class FactorAnalysis(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -562,6 +573,22 @@ class FactorAnalysis(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -577,8 +604,8 @@ class FactorAnalysis(BaseTransformer):

         return output_df

-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -591,13 +618,21 @@ class FactorAnalysis(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/decomposition/fast_ica.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class FastICA(BaseTransformer):
     r"""FastICA: a fast algorithm for Independent Component Analysis
     For more details on this class, see [sklearn.decomposition.FastICA]
@@ -192,7 +204,9 @@ class FastICA(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -277,11 +291,6 @@ class FastICA(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -309,7 +318,9 @@ class FastICA(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -580,6 +591,22 @@ class FastICA(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -595,8 +622,8 @@ class FastICA(BaseTransformer):

         return output_df

-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -609,13 +636,21 @@ class FastICA(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.