snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff compares publicly available package versions as released to their respective public registries; it is provided for informational purposes only.
Files changed (215)
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
The five representative diffs below show the template changes applied across the autogenerated modeling wrappers listed above (each +51 -16): two module-level availability predicates, external-memory plumbing in the constructor and the trainer call, removal of the conda-channel package validation, a fallback output-type inference in transform(), and rewritten fit_predict()/fit_transform() bodies.

snowflake/ml/modeling/lightgbm/lgbm_regressor.py

@@ -53,6 +53,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "lightgbm".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class LGBMRegressor(BaseTransformer):
     r"""LightGBM regressor
     For more details on this class, see [lightgbm.LGBMRegressor]
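
These predicates gate method exposure through scikit-learn's available_if descriptor. Note that the generated check is "False and callable(...)", which always evaluates to False, so fit_predict/fit_transform stay hidden on wrappers such as LGBMRegressor whose underlying estimator does not support them. A minimal sketch of the gating pattern, assuming only scikit-learn (Wrapped, _backend, and _has_backend_method are illustrative names, not snowflake-ml-python API):

    from sklearn.utils.metaestimators import available_if

    def _has_backend_method(name):
        # available_if calls the predicate with the instance; a falsy result
        # hides the method, so plain attribute access raises AttributeError.
        def check(self):
            return callable(getattr(self._backend, name, None))
        return check

    class Wrapped:
        def __init__(self, backend):
            self._backend = backend

        @available_if(_has_backend_method("fit_predict"))
        def fit_predict(self, X):
            return self._backend.fit_predict(X)

With a clusterer as backend, hasattr(Wrapped(KMeans()), "fit_predict") is True; with a plain regressor it is False, which is the effect the hard-coded False predicates force for this class.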
@@ -144,7 +156,9 @@ class LGBMRegressor(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'lightgbm=={lightgbm.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
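
The two new fields are threaded through every generated wrapper's constructor and trainer call but default to off here; judging by the file list, only the XGBoost wrappers (xgb_*.py, +69 -16) together with the new xgboost_external_memory_trainer.py actually exercise them. For context, a hedged sketch of what external-memory training means in XGBoost itself; the batch source below is a stand-in, not the trainer's implementation:

    import numpy as np
    import xgboost as xgb

    class BatchIter(xgb.DataIter):
        """Feeds (X, y) batches to XGBoost instead of one in-memory matrix."""

        def __init__(self, batches):
            self._batches = batches
            self._it = 0
            super().__init__(cache_prefix="./cache")  # spill pages to disk

        def next(self, input_data):
            if self._it == len(self._batches):
                return 0  # 0 ends the pass over the data
            X, y = self._batches[self._it]
            input_data(data=X, label=y)
            self._it += 1
            return 1

        def reset(self):
            self._it = 0

    batches = [(np.random.rand(256, 4), np.random.rand(256)) for _ in range(4)]
    dtrain = xgb.DMatrix(BatchIter(batches))  # external-memory DMatrix
    booster = xgb.train({"tree_method": "hist"}, dtrain, num_boost_round=10)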
@@ -240,11 +254,6 @@ class LGBMRegressor(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -272,7 +281,9 @@ class LGBMRegressor(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -543,6 +554,22 @@ class LGBMRegressor(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
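
Distilled, the fallback rule is: a cluster or component count that disagrees with the output column count means each row carries a list, hence ARRAY; otherwise a scalar Snowflake type is used only when every inferred type agrees and the input and output column counts match. A standalone paraphrase of the logic above (function and parameter names are illustrative, not library API):

    def infer_expected_dtype(est, output_cols, snowpark_types, to_sf_type):
        # Cluster/component count mismatch: each row holds a list -> ARRAY.
        if hasattr(est, "n_clusters") and est.n_clusters != len(output_cols):
            return "ARRAY"
        if hasattr(est, "n_components") and est.n_components != len(output_cols):
            return "ARRAY"
        # A scalar column type only if all types agree and counts line up.
        if snowpark_types and all(t == snowpark_types[0] for t in snowpark_types) \
                and len(snowpark_types) == len(output_cols):
            return to_sf_type(snowpark_types[0])
        return ""  # caller falls back to variant handling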
@@ -558,8 +585,8 @@ class LGBMRegressor(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -572,13 +599,21 @@ class LGBMRegressor(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels: npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.

snowflake/ml/modeling/linear_model/ard_regression.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class ARDRegression(BaseTransformer):
     r"""Bayesian ARD regression
     For more details on this class, see [sklearn.linear_model.ARDRegression]
@@ -179,7 +191,9 @@ class ARDRegression(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -266,11 +280,6 @@ class ARDRegression(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -298,7 +307,9 @@ class ARDRegression(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -569,6 +580,22 @@ class ARDRegression(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -584,8 +611,8 @@ class ARDRegression(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -598,13 +625,21 @@ class ARDRegression(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels: npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.

snowflake/ml/modeling/linear_model/bayesian_ridge.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class BayesianRidge(BaseTransformer):
     r"""Bayesian ridge regression
     For more details on this class, see [sklearn.linear_model.BayesianRidge]
@@ -189,7 +201,9 @@ class BayesianRidge(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -277,11 +291,6 @@ class BayesianRidge(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -309,7 +318,9 @@ class BayesianRidge(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -580,6 +591,22 @@ class BayesianRidge(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -595,8 +622,8 @@ class BayesianRidge(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -609,13 +636,21 @@ class BayesianRidge(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels: npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.

snowflake/ml/modeling/linear_model/elastic_net.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class ElasticNet(BaseTransformer):
     r"""Linear regression with combined L1 and L2 priors as regularizer
     For more details on this class, see [sklearn.linear_model.ElasticNet]
@@ -190,7 +202,9 @@ class ElasticNet(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -276,11 +290,6 @@ class ElasticNet(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -308,7 +317,9 @@ class ElasticNet(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -579,6 +590,22 @@ class ElasticNet(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -594,8 +621,8 @@ class ElasticNet(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -608,13 +635,21 @@ class ElasticNet(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels: npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.

snowflake/ml/modeling/linear_model/elastic_net_cv.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class ElasticNetCV(BaseTransformer):
     r"""Elastic Net model with iterative fitting along a regularization path
     For more details on this class, see [sklearn.linear_model.ElasticNetCV]
@@ -222,7 +234,9 @@ class ElasticNetCV(BaseTransformer):
         self.set_label_cols(label_cols)
        self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -312,11 +326,6 @@ class ElasticNetCV(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -344,7 +353,9 @@ class ElasticNetCV(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -615,6 +626,22 @@ class ElasticNetCV(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -630,8 +657,8 @@ class ElasticNetCV(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -644,13 +671,21 @@ class ElasticNetCV(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels: npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.