snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
Files changed (215)
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
snowflake/ml/modeling/impute/iterative_imputer.py

@@ -55,6 +55,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.impute".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class IterativeImputer(BaseTransformer):
     r"""Multivariate imputer that estimates each feature from all the others
     For more details on this class, see [sklearn.impute.IterativeImputer]
@@ -241,7 +253,9 @@ class IterativeImputer(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
         deps = deps | gather_dependencies(estimator)
         self._deps = list(deps)
@@ -332,11 +346,6 @@ class IterativeImputer(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -364,7 +373,9 @@ class IterativeImputer(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -635,6 +646,22 @@ class IterativeImputer(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -650,8 +677,8 @@ class IterativeImputer(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -664,13 +691,21 @@ class IterativeImputer(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/impute/knn_imputer.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.impute".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class KNNImputer(BaseTransformer):
     r"""Imputation for completing missing values using k-Nearest Neighbors
     For more details on this class, see [sklearn.impute.KNNImputer]
@@ -176,7 +188,9 @@ class KNNImputer(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -258,11 +272,6 @@ class KNNImputer(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -290,7 +299,9 @@ class KNNImputer(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -561,6 +572,22 @@ class KNNImputer(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -576,8 +603,8 @@ class KNNImputer(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -590,13 +617,21 @@ class KNNImputer(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/impute/missing_indicator.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.impute".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class MissingIndicator(BaseTransformer):
     r"""Binary indicators for missing values
     For more details on this class, see [sklearn.impute.MissingIndicator]
@@ -153,7 +165,9 @@ class MissingIndicator(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -232,11 +246,6 @@ class MissingIndicator(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -264,7 +273,9 @@ class MissingIndicator(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -535,6 +546,22 @@ class MissingIndicator(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -550,8 +577,8 @@ class MissingIndicator(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -564,13 +591,21 @@ class MissingIndicator(BaseTransformer):
         Returns:
            Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.kernel_approximation".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class AdditiveChi2Sampler(BaseTransformer):
     r"""Approximate feature map for additive chi2 kernel
     For more details on this class, see [sklearn.kernel_approximation.AdditiveChi2Sampler]
@@ -130,7 +142,9 @@ class AdditiveChi2Sampler(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -207,11 +221,6 @@ class AdditiveChi2Sampler(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -239,7 +248,9 @@ class AdditiveChi2Sampler(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -510,6 +521,22 @@ class AdditiveChi2Sampler(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -525,8 +552,8 @@ class AdditiveChi2Sampler(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -539,13 +566,21 @@ class AdditiveChi2Sampler(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/kernel_approximation/nystroem.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.kernel_approximation".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class Nystroem(BaseTransformer):
     r"""Approximate a kernel map using a subset of the training data
     For more details on this class, see [sklearn.kernel_approximation.Nystroem]
@@ -172,7 +184,9 @@ class Nystroem(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
        self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -255,11 +269,6 @@ class Nystroem(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -287,7 +296,9 @@ class Nystroem(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -558,6 +569,22 @@ class Nystroem(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -573,8 +600,8 @@ class Nystroem(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -587,13 +614,21 @@ class Nystroem(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.