snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215) hide show
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
54
54
  _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
55
55
 
56
56
 
57
+ def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
58
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
59
+ return False and callable(getattr(self._sklearn_object, "fit_predict", None))
60
+ return check
61
+
62
+
63
+ def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
+ return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
+ return check
67
+
68
+
57
69
  class PoissonRegressor(BaseTransformer):
58
70
  r"""Generalized Linear Model with a Poisson distribution
59
71
  For more details on this class, see [sklearn.linear_model.PoissonRegressor]
@@ -175,7 +187,9 @@ class PoissonRegressor(BaseTransformer):
175
187
  self.set_label_cols(label_cols)
176
188
  self.set_passthrough_cols(passthrough_cols)
177
189
  self.set_drop_input_cols(drop_input_cols)
178
- self.set_sample_weight_col(sample_weight_col)
190
+ self.set_sample_weight_col(sample_weight_col)
191
+ self._use_external_memory_version = False
192
+ self._batch_size = -1
179
193
  deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
180
194
 
181
195
  self._deps = list(deps)
@@ -257,11 +271,6 @@ class PoissonRegressor(BaseTransformer):
257
271
  if isinstance(dataset, DataFrame):
258
272
  session = dataset._session
259
273
  assert session is not None # keep mypy happy
260
- # Validate that key package version in user workspace are supported in snowflake conda channel
261
- # If customer doesn't have package in conda channel, replace the ones have the closest versions
262
- self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
263
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
264
-
265
274
  # Specify input columns so column pruning will be enforced
266
275
  selected_cols = self._get_active_columns()
267
276
  if len(selected_cols) > 0:
@@ -289,7 +298,9 @@ class PoissonRegressor(BaseTransformer):
289
298
  label_cols=self.label_cols,
290
299
  sample_weight_col=self.sample_weight_col,
291
300
  autogenerated=self._autogenerated,
292
- subproject=_SUBPROJECT
301
+ subproject=_SUBPROJECT,
302
+ use_external_memory_version=self._use_external_memory_version,
303
+ batch_size=self._batch_size,
293
304
  )
294
305
  self._sklearn_object = model_trainer.train()
295
306
  self._is_fitted = True
@@ -560,6 +571,22 @@ class PoissonRegressor(BaseTransformer):
560
571
  # each row containing a list of values.
561
572
  expected_dtype = "ARRAY"
562
573
 
574
+ # If we were unable to assign a type to this transform in the factory, infer the type here.
575
+ if expected_dtype == "":
576
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
577
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
578
+ expected_dtype = "ARRAY"
579
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
580
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
581
+ expected_dtype = "ARRAY"
582
+ else:
583
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
584
+ # We can only infer the output types from the input types if the following two statemetns are true:
585
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
586
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
587
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
588
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
589
+
563
590
  output_df = self._batch_inference(
564
591
  dataset=dataset,
565
592
  inference_method="transform",
@@ -575,8 +602,8 @@ class PoissonRegressor(BaseTransformer):
575
602
 
576
603
  return output_df
577
604
 
578
- @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
579
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
605
+ @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
606
+ def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
580
607
  """ Method not supported for this class.
581
608
 
582
609
 
@@ -589,13 +616,21 @@ class PoissonRegressor(BaseTransformer):
589
616
  Returns:
590
617
  Predicted dataset.
591
618
  """
592
- if False:
593
- self.fit(dataset)
594
- assert self._sklearn_object is not None
595
- labels : npt.NDArray[Any] = self._sklearn_object.labels_
596
- return labels
597
- else:
598
- raise NotImplementedError
619
+ self.fit(dataset)
620
+ assert self._sklearn_object is not None
621
+ return self._sklearn_object.labels_
622
+
623
+
624
+ @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
625
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
626
+ """
627
+ Returns:
628
+ Transformed dataset.
629
+ """
630
+ self.fit(dataset)
631
+ assert self._sklearn_object is not None
632
+ return self._sklearn_object.embedding_
633
+
599
634
 
600
635
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
601
636
  """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
54
54
  _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
55
55
 
56
56
 
57
+ def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
58
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
59
+ return False and callable(getattr(self._sklearn_object, "fit_predict", None))
60
+ return check
61
+
62
+
63
+ def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
+ return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
+ return check
67
+
68
+
57
69
  class RANSACRegressor(BaseTransformer):
58
70
  r"""RANSAC (RANdom SAmple Consensus) algorithm
59
71
  For more details on this class, see [sklearn.linear_model.RANSACRegressor]
@@ -226,7 +238,9 @@ class RANSACRegressor(BaseTransformer):
226
238
  self.set_label_cols(label_cols)
227
239
  self.set_passthrough_cols(passthrough_cols)
228
240
  self.set_drop_input_cols(drop_input_cols)
229
- self.set_sample_weight_col(sample_weight_col)
241
+ self.set_sample_weight_col(sample_weight_col)
242
+ self._use_external_memory_version = False
243
+ self._batch_size = -1
230
244
  deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
231
245
  deps = deps | gather_dependencies(estimator)
232
246
  self._deps = list(deps)
@@ -313,11 +327,6 @@ class RANSACRegressor(BaseTransformer):
313
327
  if isinstance(dataset, DataFrame):
314
328
  session = dataset._session
315
329
  assert session is not None # keep mypy happy
316
- # Validate that key package version in user workspace are supported in snowflake conda channel
317
- # If customer doesn't have package in conda channel, replace the ones have the closest versions
318
- self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
319
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
320
-
321
330
  # Specify input columns so column pruning will be enforced
322
331
  selected_cols = self._get_active_columns()
323
332
  if len(selected_cols) > 0:
@@ -345,7 +354,9 @@ class RANSACRegressor(BaseTransformer):
345
354
  label_cols=self.label_cols,
346
355
  sample_weight_col=self.sample_weight_col,
347
356
  autogenerated=self._autogenerated,
348
- subproject=_SUBPROJECT
357
+ subproject=_SUBPROJECT,
358
+ use_external_memory_version=self._use_external_memory_version,
359
+ batch_size=self._batch_size,
349
360
  )
350
361
  self._sklearn_object = model_trainer.train()
351
362
  self._is_fitted = True
@@ -616,6 +627,22 @@ class RANSACRegressor(BaseTransformer):
616
627
  # each row containing a list of values.
617
628
  expected_dtype = "ARRAY"
618
629
 
630
+ # If we were unable to assign a type to this transform in the factory, infer the type here.
631
+ if expected_dtype == "":
632
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
633
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
634
+ expected_dtype = "ARRAY"
635
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
636
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
637
+ expected_dtype = "ARRAY"
638
+ else:
639
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
640
+ # We can only infer the output types from the input types if the following two statemetns are true:
641
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
642
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
643
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
644
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
645
+
619
646
  output_df = self._batch_inference(
620
647
  dataset=dataset,
621
648
  inference_method="transform",
@@ -631,8 +658,8 @@ class RANSACRegressor(BaseTransformer):
631
658
 
632
659
  return output_df
633
660
 
634
- @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
635
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
661
+ @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
662
+ def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
636
663
  """ Method not supported for this class.
637
664
 
638
665
 
@@ -645,13 +672,21 @@ class RANSACRegressor(BaseTransformer):
645
672
  Returns:
646
673
  Predicted dataset.
647
674
  """
648
- if False:
649
- self.fit(dataset)
650
- assert self._sklearn_object is not None
651
- labels : npt.NDArray[Any] = self._sklearn_object.labels_
652
- return labels
653
- else:
654
- raise NotImplementedError
675
+ self.fit(dataset)
676
+ assert self._sklearn_object is not None
677
+ return self._sklearn_object.labels_
678
+
679
+
680
+ @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
681
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
682
+ """
683
+ Returns:
684
+ Transformed dataset.
685
+ """
686
+ self.fit(dataset)
687
+ assert self._sklearn_object is not None
688
+ return self._sklearn_object.embedding_
689
+
655
690
 
656
691
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
657
692
  """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
54
54
  _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
55
55
 
56
56
 
57
+ def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
58
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
59
+ return False and callable(getattr(self._sklearn_object, "fit_predict", None))
60
+ return check
61
+
62
+
63
+ def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
+ return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
+ return check
67
+
68
+
57
69
  class Ridge(BaseTransformer):
58
70
  r"""Linear least squares with l2 regularization
59
71
  For more details on this class, see [sklearn.linear_model.Ridge]
@@ -222,7 +234,9 @@ class Ridge(BaseTransformer):
222
234
  self.set_label_cols(label_cols)
223
235
  self.set_passthrough_cols(passthrough_cols)
224
236
  self.set_drop_input_cols(drop_input_cols)
225
- self.set_sample_weight_col(sample_weight_col)
237
+ self.set_sample_weight_col(sample_weight_col)
238
+ self._use_external_memory_version = False
239
+ self._batch_size = -1
226
240
  deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
227
241
 
228
242
  self._deps = list(deps)
@@ -305,11 +319,6 @@ class Ridge(BaseTransformer):
305
319
  if isinstance(dataset, DataFrame):
306
320
  session = dataset._session
307
321
  assert session is not None # keep mypy happy
308
- # Validate that key package version in user workspace are supported in snowflake conda channel
309
- # If customer doesn't have package in conda channel, replace the ones have the closest versions
310
- self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
311
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
312
-
313
322
  # Specify input columns so column pruning will be enforced
314
323
  selected_cols = self._get_active_columns()
315
324
  if len(selected_cols) > 0:
@@ -337,7 +346,9 @@ class Ridge(BaseTransformer):
337
346
  label_cols=self.label_cols,
338
347
  sample_weight_col=self.sample_weight_col,
339
348
  autogenerated=self._autogenerated,
340
- subproject=_SUBPROJECT
349
+ subproject=_SUBPROJECT,
350
+ use_external_memory_version=self._use_external_memory_version,
351
+ batch_size=self._batch_size,
341
352
  )
342
353
  self._sklearn_object = model_trainer.train()
343
354
  self._is_fitted = True
@@ -608,6 +619,22 @@ class Ridge(BaseTransformer):
608
619
  # each row containing a list of values.
609
620
  expected_dtype = "ARRAY"
610
621
 
622
+ # If we were unable to assign a type to this transform in the factory, infer the type here.
623
+ if expected_dtype == "":
624
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
625
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
626
+ expected_dtype = "ARRAY"
627
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
628
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
629
+ expected_dtype = "ARRAY"
630
+ else:
631
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
632
+ # We can only infer the output types from the input types if the following two statemetns are true:
633
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
634
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
635
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
636
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
637
+
611
638
  output_df = self._batch_inference(
612
639
  dataset=dataset,
613
640
  inference_method="transform",
@@ -623,8 +650,8 @@ class Ridge(BaseTransformer):
623
650
 
624
651
  return output_df
625
652
 
626
- @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
627
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
653
+ @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
654
+ def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
628
655
  """ Method not supported for this class.
629
656
 
630
657
 
@@ -637,13 +664,21 @@ class Ridge(BaseTransformer):
637
664
  Returns:
638
665
  Predicted dataset.
639
666
  """
640
- if False:
641
- self.fit(dataset)
642
- assert self._sklearn_object is not None
643
- labels : npt.NDArray[Any] = self._sklearn_object.labels_
644
- return labels
645
- else:
646
- raise NotImplementedError
667
+ self.fit(dataset)
668
+ assert self._sklearn_object is not None
669
+ return self._sklearn_object.labels_
670
+
671
+
672
+ @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
673
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
674
+ """
675
+ Returns:
676
+ Transformed dataset.
677
+ """
678
+ self.fit(dataset)
679
+ assert self._sklearn_object is not None
680
+ return self._sklearn_object.embedding_
681
+
647
682
 
648
683
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
649
684
  """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
54
54
  _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
55
55
 
56
56
 
57
+ def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
58
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
59
+ return False and callable(getattr(self._sklearn_object, "fit_predict", None))
60
+ return check
61
+
62
+
63
+ def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
+ return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
+ return check
67
+
68
+
57
69
  class RidgeClassifier(BaseTransformer):
58
70
  r"""Classifier using Ridge regression
59
71
  For more details on this class, see [sklearn.linear_model.RidgeClassifier]
@@ -221,7 +233,9 @@ class RidgeClassifier(BaseTransformer):
221
233
  self.set_label_cols(label_cols)
222
234
  self.set_passthrough_cols(passthrough_cols)
223
235
  self.set_drop_input_cols(drop_input_cols)
224
- self.set_sample_weight_col(sample_weight_col)
236
+ self.set_sample_weight_col(sample_weight_col)
237
+ self._use_external_memory_version = False
238
+ self._batch_size = -1
225
239
  deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
226
240
 
227
241
  self._deps = list(deps)
@@ -305,11 +319,6 @@ class RidgeClassifier(BaseTransformer):
305
319
  if isinstance(dataset, DataFrame):
306
320
  session = dataset._session
307
321
  assert session is not None # keep mypy happy
308
- # Validate that key package version in user workspace are supported in snowflake conda channel
309
- # If customer doesn't have package in conda channel, replace the ones have the closest versions
310
- self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
311
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
312
-
313
322
  # Specify input columns so column pruning will be enforced
314
323
  selected_cols = self._get_active_columns()
315
324
  if len(selected_cols) > 0:
@@ -337,7 +346,9 @@ class RidgeClassifier(BaseTransformer):
337
346
  label_cols=self.label_cols,
338
347
  sample_weight_col=self.sample_weight_col,
339
348
  autogenerated=self._autogenerated,
340
- subproject=_SUBPROJECT
349
+ subproject=_SUBPROJECT,
350
+ use_external_memory_version=self._use_external_memory_version,
351
+ batch_size=self._batch_size,
341
352
  )
342
353
  self._sklearn_object = model_trainer.train()
343
354
  self._is_fitted = True
@@ -608,6 +619,22 @@ class RidgeClassifier(BaseTransformer):
608
619
  # each row containing a list of values.
609
620
  expected_dtype = "ARRAY"
610
621
 
622
+ # If we were unable to assign a type to this transform in the factory, infer the type here.
623
+ if expected_dtype == "":
624
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
625
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
626
+ expected_dtype = "ARRAY"
627
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
628
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
629
+ expected_dtype = "ARRAY"
630
+ else:
631
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
632
+ # We can only infer the output types from the input types if the following two statemetns are true:
633
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
634
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
635
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
636
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
637
+
611
638
  output_df = self._batch_inference(
612
639
  dataset=dataset,
613
640
  inference_method="transform",
@@ -623,8 +650,8 @@ class RidgeClassifier(BaseTransformer):
623
650
 
624
651
  return output_df
625
652
 
626
- @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
627
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
653
+ @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
654
+ def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
628
655
  """ Method not supported for this class.
629
656
 
630
657
 
@@ -637,13 +664,21 @@ class RidgeClassifier(BaseTransformer):
637
664
  Returns:
638
665
  Predicted dataset.
639
666
  """
640
- if False:
641
- self.fit(dataset)
642
- assert self._sklearn_object is not None
643
- labels : npt.NDArray[Any] = self._sklearn_object.labels_
644
- return labels
645
- else:
646
- raise NotImplementedError
667
+ self.fit(dataset)
668
+ assert self._sklearn_object is not None
669
+ return self._sklearn_object.labels_
670
+
671
+
672
+ @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
673
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
674
+ """
675
+ Returns:
676
+ Transformed dataset.
677
+ """
678
+ self.fit(dataset)
679
+ assert self._sklearn_object is not None
680
+ return self._sklearn_object.embedding_
681
+
647
682
 
648
683
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
649
684
  """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
54
54
  _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
55
55
 
56
56
 
57
+ def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
58
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
59
+ return False and callable(getattr(self._sklearn_object, "fit_predict", None))
60
+ return check
61
+
62
+
63
+ def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
+ def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
+ return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
+ return check
67
+
68
+
57
69
  class RidgeClassifierCV(BaseTransformer):
58
70
  r"""Ridge classifier with built-in cross-validation
59
71
  For more details on this class, see [sklearn.linear_model.RidgeClassifierCV]
@@ -175,7 +187,9 @@ class RidgeClassifierCV(BaseTransformer):
175
187
  self.set_label_cols(label_cols)
176
188
  self.set_passthrough_cols(passthrough_cols)
177
189
  self.set_drop_input_cols(drop_input_cols)
178
- self.set_sample_weight_col(sample_weight_col)
190
+ self.set_sample_weight_col(sample_weight_col)
191
+ self._use_external_memory_version = False
192
+ self._batch_size = -1
179
193
  deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
180
194
 
181
195
  self._deps = list(deps)
@@ -256,11 +270,6 @@ class RidgeClassifierCV(BaseTransformer):
256
270
  if isinstance(dataset, DataFrame):
257
271
  session = dataset._session
258
272
  assert session is not None # keep mypy happy
259
- # Validate that key package version in user workspace are supported in snowflake conda channel
260
- # If customer doesn't have package in conda channel, replace the ones have the closest versions
261
- self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
262
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
263
-
264
273
  # Specify input columns so column pruning will be enforced
265
274
  selected_cols = self._get_active_columns()
266
275
  if len(selected_cols) > 0:
@@ -288,7 +297,9 @@ class RidgeClassifierCV(BaseTransformer):
288
297
  label_cols=self.label_cols,
289
298
  sample_weight_col=self.sample_weight_col,
290
299
  autogenerated=self._autogenerated,
291
- subproject=_SUBPROJECT
300
+ subproject=_SUBPROJECT,
301
+ use_external_memory_version=self._use_external_memory_version,
302
+ batch_size=self._batch_size,
292
303
  )
293
304
  self._sklearn_object = model_trainer.train()
294
305
  self._is_fitted = True
@@ -559,6 +570,22 @@ class RidgeClassifierCV(BaseTransformer):
559
570
  # each row containing a list of values.
560
571
  expected_dtype = "ARRAY"
561
572
 
573
+ # If we were unable to assign a type to this transform in the factory, infer the type here.
574
+ if expected_dtype == "":
575
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
576
+ if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
577
+ expected_dtype = "ARRAY"
578
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
579
+ elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
580
+ expected_dtype = "ARRAY"
581
+ else:
582
+ output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
583
+ # We can only infer the output types from the input types if the following two statemetns are true:
584
+ # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
585
+ # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
586
+ if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
587
+ expected_dtype = convert_sp_to_sf_type(output_types[0])
588
+
562
589
  output_df = self._batch_inference(
563
590
  dataset=dataset,
564
591
  inference_method="transform",
@@ -574,8 +601,8 @@ class RidgeClassifierCV(BaseTransformer):
574
601
 
575
602
  return output_df
576
603
 
577
- @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
578
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
604
+ @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
605
+ def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
579
606
  """ Method not supported for this class.
580
607
 
581
608
 
@@ -588,13 +615,21 @@ class RidgeClassifierCV(BaseTransformer):
588
615
  Returns:
589
616
  Predicted dataset.
590
617
  """
591
- if False:
592
- self.fit(dataset)
593
- assert self._sklearn_object is not None
594
- labels : npt.NDArray[Any] = self._sklearn_object.labels_
595
- return labels
596
- else:
597
- raise NotImplementedError
618
+ self.fit(dataset)
619
+ assert self._sklearn_object is not None
620
+ return self._sklearn_object.labels_
621
+
622
+
623
+ @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
624
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
625
+ """
626
+ Returns:
627
+ Transformed dataset.
628
+ """
629
+ self.fit(dataset)
630
+ assert self._sklearn_object is not None
631
+ return self._sklearn_object.embedding_
632
+
598
633
 
599
634
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
600
635
  """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.