snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registries. It is provided for informational purposes only.
Files changed (215)
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
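Among the 215 files, the registry rework stands out: snowflake/ml/registry/registry.py is largely rewritten around the new _manager/model_manager.py (items 202-205). As orientation only, and not part of this diff, here is a minimal usage sketch of the Registry entry point; it assumes the log_model/run API as publicly documented for the 1.2 release line, and all credentials and object names (ML_DB, PUBLIC, IRIS_CLF) are placeholders.

# Hedged sketch, not from this diff: exercising the reworked registry
# (snowflake.ml.registry.Registry). Credentials and object names below are
# placeholders; log_model/run follow the publicly documented 1.2 API.
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from snowflake.ml.registry import Registry
from snowflake.snowpark import Session

connection_parameters = {
    "account": "<account>",  # placeholder credentials
    "user": "<user>",
    "password": "<password>",
}
session = Session.builder.configs(connection_parameters).create()

X, y = load_iris(return_X_y=True, as_frame=True)
X.columns = ["SEPAL_LENGTH", "SEPAL_WIDTH", "PETAL_LENGTH", "PETAL_WIDTH"]
model = LogisticRegression(max_iter=1000).fit(X, y)

reg = Registry(session=session, database_name="ML_DB", schema_name="PUBLIC")
mv = reg.log_model(model, model_name="IRIS_CLF", version_name="V1")  # returns a ModelVersion handle
preds = mv.run(X, function_name="predict")  # dispatch to the model's predict function

The hunks reproduced below come from five of the autogenerated modeling wrappers (items 133-137); the same pattern repeats across the other wrapper files listed above.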
snowflake/ml/modeling/linear_model/logistic_regression.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class LogisticRegression(BaseTransformer):
     r"""Logistic Regression (aka logit, MaxEnt) classifier
     For more details on this class, see [sklearn.linear_model.LogisticRegression]
@@ -251,7 +263,9 @@ class LogisticRegression(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -341,11 +355,6 @@ class LogisticRegression(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -373,7 +382,9 @@ class LogisticRegression(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -644,6 +655,22 @@ class LogisticRegression(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -659,8 +686,8 @@ class LogisticRegression(BaseTransformer):

         return output_df

-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -673,13 +700,21 @@ class LogisticRegression(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/linear_model/logistic_regression_cv.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class LogisticRegressionCV(BaseTransformer):
     r"""Logistic Regression CV (aka logit, MaxEnt) classifier
     For more details on this class, see [sklearn.linear_model.LogisticRegressionCV]
@@ -270,7 +282,9 @@ class LogisticRegressionCV(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -362,11 +376,6 @@ class LogisticRegressionCV(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -394,7 +403,9 @@ class LogisticRegressionCV(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -665,6 +676,22 @@ class LogisticRegressionCV(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -680,8 +707,8 @@ class LogisticRegressionCV(BaseTransformer):

         return output_df

-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -694,13 +721,21 @@ class LogisticRegressionCV(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/linear_model/multi_task_elastic_net.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class MultiTaskElasticNet(BaseTransformer):
     r"""Multi-task ElasticNet model trained with L1/L2 mixed-norm as regularizer
     For more details on this class, see [sklearn.linear_model.MultiTaskElasticNet]
@@ -176,7 +188,9 @@ class MultiTaskElasticNet(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -260,11 +274,6 @@ class MultiTaskElasticNet(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -292,7 +301,9 @@ class MultiTaskElasticNet(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -563,6 +574,22 @@ class MultiTaskElasticNet(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -578,8 +605,8 @@ class MultiTaskElasticNet(BaseTransformer):

         return output_df

-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -592,13 +619,21 @@ class MultiTaskElasticNet(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class MultiTaskElasticNetCV(BaseTransformer):
     r"""Multi-task L1/L2 ElasticNet with built-in cross-validation
     For more details on this class, see [sklearn.linear_model.MultiTaskElasticNetCV]
@@ -213,7 +225,9 @@ class MultiTaskElasticNetCV(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -301,11 +315,6 @@ class MultiTaskElasticNetCV(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -333,7 +342,9 @@ class MultiTaskElasticNetCV(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -604,6 +615,22 @@ class MultiTaskElasticNetCV(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -619,8 +646,8 @@ class MultiTaskElasticNetCV(BaseTransformer):

         return output_df

-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -633,13 +660,21 @@ class MultiTaskElasticNetCV(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/linear_model/multi_task_lasso.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class MultiTaskLasso(BaseTransformer):
     r"""Multi-task Lasso model trained with L1/L2 mixed-norm as regularizer
     For more details on this class, see [sklearn.linear_model.MultiTaskLasso]
@@ -169,7 +181,9 @@ class MultiTaskLasso(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -252,11 +266,6 @@ class MultiTaskLasso(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -284,7 +293,9 @@ class MultiTaskLasso(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -555,6 +566,22 @@ class MultiTaskLasso(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -570,8 +597,8 @@ class MultiTaskLasso(BaseTransformer):

         return output_df

-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -584,13 +611,21 @@ class MultiTaskLasso(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.