snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
snowflake/ml/modeling/linear_model/lasso_cv.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class LassoCV(BaseTransformer):
     r"""Lasso linear model with iterative fitting along a regularization path
     For more details on this class, see [sklearn.linear_model.LassoCV]
@@ -209,7 +221,9 @@ class LassoCV(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -298,11 +312,6 @@ class LassoCV(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -330,7 +339,9 @@ class LassoCV(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -601,6 +612,22 @@ class LassoCV(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -616,8 +643,8 @@ class LassoCV(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -630,13 +657,21 @@ class LassoCV(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/linear_model/lasso_lars.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class LassoLars(BaseTransformer):
     r"""Lasso model fit with Least Angle Regression a
     For more details on this class, see [sklearn.linear_model.LassoLars]
@@ -203,7 +215,9 @@ class LassoLars(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -290,11 +304,6 @@ class LassoLars(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -322,7 +331,9 @@ class LassoLars(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -593,6 +604,22 @@ class LassoLars(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -608,8 +635,8 @@ class LassoLars(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -622,13 +649,21 @@ class LassoLars(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/linear_model/lasso_lars_cv.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class LassoLarsCV(BaseTransformer):
     r"""Cross-validated Lasso, using the LARS algorithm
     For more details on this class, see [sklearn.linear_model.LassoLarsCV]
@@ -205,7 +217,9 @@ class LassoLarsCV(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -291,11 +305,6 @@ class LassoLarsCV(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -323,7 +332,9 @@ class LassoLarsCV(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -594,6 +605,22 @@ class LassoLarsCV(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -609,8 +636,8 @@ class LassoLarsCV(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -623,13 +650,21 @@ class LassoLarsCV(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/linear_model/lasso_lars_ic.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class LassoLarsIC(BaseTransformer):
     r"""Lasso model fit with Lars using BIC or AIC for model selection
     For more details on this class, see [sklearn.linear_model.LassoLarsIC]
@@ -189,7 +201,9 @@ class LassoLarsIC(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -274,11 +288,6 @@ class LassoLarsIC(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -306,7 +315,9 @@ class LassoLarsIC(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -577,6 +588,22 @@ class LassoLarsIC(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -592,8 +619,8 @@ class LassoLarsIC(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
        """ Method not supported for this class.
 
 
@@ -606,13 +633,21 @@ class LassoLarsIC(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/linear_model/linear_regression.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class LinearRegression(BaseTransformer):
     r"""Ordinary least squares Linear Regression
     For more details on this class, see [sklearn.linear_model.LinearRegression]
@@ -148,7 +160,9 @@ class LinearRegression(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -227,11 +241,6 @@ class LinearRegression(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -259,7 +268,9 @@ class LinearRegression(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -530,6 +541,22 @@ class LinearRegression(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -545,8 +572,8 @@ class LinearRegression(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -559,13 +586,21 @@ class LinearRegression(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.