snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0

snowflake/ml/modeling/feature_selection/select_percentile.py
@@ -55,6 +55,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.feature_selection".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class SelectPercentile(BaseTransformer):
     r"""Select features according to a percentile of the highest scores
     For more details on this class, see [sklearn.feature_selection.SelectPercentile]
@@ -136,7 +148,9 @@ class SelectPercentile(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -213,11 +227,6 @@ class SelectPercentile(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -245,7 +254,9 @@ class SelectPercentile(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
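
Every generated estimator in this release initializes _use_external_memory_version = False and _batch_size = -1 and forwards both to the trainer builder, matching the new model_trainer_builder.py and xgboost_external_memory_trainer.py entries in the file list; judging by their larger +69/-16 deltas, the XGBoost estimators appear to be the ones that actually turn the feature on. A hypothetical sketch of the dispatch these flags suggest (TrainerConfig and build_trainer are illustrative names, not part of the package):

    # Illustrative only: the real ModelTrainerBuilder internals are not shown
    # in this diff, just that it now accepts these two keyword arguments.
    from dataclasses import dataclass

    @dataclass
    class TrainerConfig:
        use_external_memory_version: bool = False
        batch_size: int = -1  # sentinel the generated estimators pass by default

    def build_trainer(config: TrainerConfig) -> str:
        # A positive batch size with the flag set would select a trainer that
        # streams batches from disk instead of loading everything into memory.
        if config.use_external_memory_version and config.batch_size > 0:
            return f"external-memory trainer ({config.batch_size} rows/batch)"
        return "in-memory trainer"

    print(build_trainer(TrainerConfig()))              # in-memory trainer
    print(build_trainer(TrainerConfig(True, 10000)))   # external-memory trainer (10000 rows/batch)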
@@ -516,6 +527,22 @@ class SelectPercentile(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
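
The block added above fills in the expected Snowflake output type whenever the generation factory left it blank: a cluster or component count that disagrees with the number of output columns forces an ARRAY column, and otherwise the type is copied from the inputs when they are homogeneous and map one-to-one onto the outputs. A minimal, self-contained sketch of the same decision logic, using plain strings in place of Snowpark types (infer_expected_dtype is an illustrative name, not a function in the package):

    from typing import List, Optional

    def infer_expected_dtype(
        output_cols: List[str],
        input_types: List[str],
        n_clusters: Optional[int] = None,
        n_components: Optional[int] = None,
    ) -> str:
        # A cluster/component count that differs from the output column count
        # means each row carries a list of values, i.e. an ARRAY column.
        if n_clusters is not None and n_clusters != len(output_cols):
            return "ARRAY"
        if n_components is not None and n_components != len(output_cols):
            return "ARRAY"
        # Copy the input type only if all inputs agree and the column counts
        # match one-to-one; anything else stays unresolved.
        if input_types and all(t == input_types[0] for t in input_types) and len(input_types) == len(output_cols):
            return input_types[0]
        return ""  # left unset: the caller falls back to a variant column

    print(infer_expected_dtype(["OUT1", "OUT2"], ["FLOAT", "FLOAT"]))       # FLOAT
    print(infer_expected_dtype(["OUT"], ["FLOAT", "FLOAT"], n_clusters=8))  # ARRAY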
@@ -531,8 +558,8 @@ class SelectPercentile(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -545,13 +572,21 @@ class SelectPercentile(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.

snowflake/ml/modeling/feature_selection/sequential_feature_selector.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.feature_selection".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class SequentialFeatureSelector(BaseTransformer):
     r"""Transformer that performs Sequential Feature Selection
     For more details on this class, see [sklearn.feature_selection.SequentialFeatureSelector]
@@ -189,7 +201,9 @@ class SequentialFeatureSelector(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
         deps = deps | gather_dependencies(estimator)
         self._deps = list(deps)
@@ -271,11 +285,6 @@ class SequentialFeatureSelector(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -303,7 +312,9 @@ class SequentialFeatureSelector(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -574,6 +585,22 @@ class SequentialFeatureSelector(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -589,8 +616,8 @@ class SequentialFeatureSelector(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -603,13 +630,21 @@ class SequentialFeatureSelector(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.

snowflake/ml/modeling/feature_selection/variance_threshold.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.feature_selection".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class VarianceThreshold(BaseTransformer):
     r"""Feature selector that removes all low-variance features
     For more details on this class, see [sklearn.feature_selection.VarianceThreshold]
@@ -128,7 +140,9 @@ class VarianceThreshold(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -204,11 +218,6 @@ class VarianceThreshold(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -236,7 +245,9 @@ class VarianceThreshold(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -507,6 +518,22 @@ class VarianceThreshold(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -522,8 +549,8 @@ class VarianceThreshold(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -536,13 +563,21 @@ class VarianceThreshold(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.

snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.gaussian_process".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class GaussianProcessClassifier(BaseTransformer):
     r"""Gaussian process classification (GPC) based on Laplace approximation
     For more details on this class, see [sklearn.gaussian_process.GaussianProcessClassifier]
@@ -215,7 +227,9 @@ class GaussianProcessClassifier(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -299,11 +313,6 @@ class GaussianProcessClassifier(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -331,7 +340,9 @@ class GaussianProcessClassifier(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -602,6 +613,22 @@ class GaussianProcessClassifier(BaseTransformer):
             # each row containing a list of values.
            expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -617,8 +644,8 @@ class GaussianProcessClassifier(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -631,13 +658,21 @@ class GaussianProcessClassifier(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.

snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.gaussian_process".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class GaussianProcessRegressor(BaseTransformer):
     r"""Gaussian process regression (GPR)
     For more details on this class, see [sklearn.gaussian_process.GaussianProcessRegressor]
@@ -207,7 +219,9 @@ class GaussianProcessRegressor(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -290,11 +304,6 @@ class GaussianProcessRegressor(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -322,7 +331,9 @@ class GaussianProcessRegressor(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
            autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -593,6 +604,22 @@ class GaussianProcessRegressor(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -608,8 +635,8 @@ class GaussianProcessRegressor(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -622,13 +649,21 @@ class GaussianProcessRegressor(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.