snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
snowflake/ml/modeling/linear_model/gamma_regressor.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class GammaRegressor(BaseTransformer):
     r"""Generalized Linear Model with a Gamma distribution
     For more details on this class, see [sklearn.linear_model.GammaRegressor]
@@ -175,7 +187,9 @@ class GammaRegressor(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -257,11 +271,6 @@ class GammaRegressor(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -289,7 +298,9 @@ class GammaRegressor(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -560,6 +571,22 @@ class GammaRegressor(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -575,8 +602,8 @@ class GammaRegressor(BaseTransformer):

         return output_df

-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -589,13 +616,21 @@ class GammaRegressor(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels: npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc. functions.
snowflake/ml/modeling/linear_model/huber_regressor.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class HuberRegressor(BaseTransformer):
     r"""L2-regularized linear regression model that is robust to outliers
     For more details on this class, see [sklearn.linear_model.HuberRegressor]
@@ -159,7 +171,9 @@ class HuberRegressor(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -240,11 +254,6 @@ class HuberRegressor(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -272,7 +281,9 @@ class HuberRegressor(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -543,6 +554,22 @@ class HuberRegressor(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -558,8 +585,8 @@ class HuberRegressor(BaseTransformer):

         return output_df

-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -572,13 +599,21 @@ class HuberRegressor(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels: npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc. functions.
snowflake/ml/modeling/linear_model/lars.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class Lars(BaseTransformer):
     r"""Least Angle Regression model a.k.a. LAR
     For more details on this class, see [sklearn.linear_model.Lars]
@@ -184,7 +196,9 @@ class Lars(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -269,11 +283,6 @@ class Lars(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -301,7 +310,9 @@ class Lars(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -572,6 +583,22 @@ class Lars(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -587,8 +614,8 @@ class Lars(BaseTransformer):

         return output_df

-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -601,13 +628,21 @@ class Lars(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels: npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc. functions.
snowflake/ml/modeling/linear_model/lars_cv.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class LarsCV(BaseTransformer):
     r"""Cross-validated Least Angle Regression model
     For more details on this class, see [sklearn.linear_model.LarsCV]
@@ -192,7 +204,9 @@ class LarsCV(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -277,11 +291,6 @@ class LarsCV(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -309,7 +318,9 @@ class LarsCV(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -580,6 +591,22 @@ class LarsCV(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -595,8 +622,8 @@ class LarsCV(BaseTransformer):

         return output_df

-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -609,13 +636,21 @@ class LarsCV(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels: npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc. functions.
snowflake/ml/modeling/linear_model/lasso.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class Lasso(BaseTransformer):
     r"""Linear Model trained with L1 prior as regularizer (aka the Lasso)
     For more details on this class, see [sklearn.linear_model.Lasso]
@@ -185,7 +197,9 @@ class Lasso(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -270,11 +284,6 @@ class Lasso(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -302,7 +311,9 @@ class Lasso(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -573,6 +584,22 @@ class Lasso(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -588,8 +615,8 @@ class Lasso(BaseTransformer):

         return output_df

-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -602,13 +629,21 @@ class Lasso(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels: npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc. functions.