snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
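
The registry changes above (items 202-205) replace most of snowflake/ml/registry/registry.py and introduce _manager/model_manager.py. A minimal usage sketch of the 1.2.x Registry entry point, assuming an already-authenticated Snowpark session; the model and data below are illustrative, not taken from this diff:

import pandas as pd
from sklearn.linear_model import LogisticRegression
from snowflake.ml.registry import Registry
from snowflake.snowpark import Session

# connection_parameters (account, user, password, ...) are user-supplied.
# session = Session.builder.configs(connection_parameters).create()

X = pd.DataFrame({"F1": [0.0, 1.0, 2.0, 3.0], "F2": [1.0, 0.0, 1.0, 0.0]})
model = LogisticRegression().fit(X, [0, 0, 1, 1])

reg = Registry(session=session, database_name="ML", schema_name="REGISTRY")
mv = reg.log_model(
    model,
    model_name="MY_MODEL",
    version_name="V1",
    sample_input_data=X,  # lets the registry infer the model signature
)
print(mv.run(X, function_name="predict"))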
snowflake/ml/modeling/svm/linear_svc.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.svm".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class LinearSVC(BaseTransformer):
     r"""Linear Support Vector Classification
     For more details on this class, see [sklearn.svm.LinearSVC]
@@ -214,7 +226,9 @@ class LinearSVC(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -301,11 +315,6 @@ class LinearSVC(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -333,7 +342,9 @@ class LinearSVC(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -604,6 +615,22 @@ class LinearSVC(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -619,8 +646,8 @@ class LinearSVC(BaseTransformer):

         return output_df

-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -633,13 +660,21 @@ class LinearSVC(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/svm/linear_svr.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.svm".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class LinearSVR(BaseTransformer):
     r"""Linear Support Vector Regression
     For more details on this class, see [sklearn.svm.LinearSVR]
@@ -188,7 +200,9 @@ class LinearSVR(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -273,11 +287,6 @@ class LinearSVR(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -305,7 +314,9 @@ class LinearSVR(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -576,6 +587,22 @@ class LinearSVR(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -591,8 +618,8 @@ class LinearSVR(BaseTransformer):

         return output_df

-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -605,13 +632,21 @@ class LinearSVR(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/svm/nu_svc.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.svm".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class NuSVC(BaseTransformer):
     r"""Nu-Support Vector Classification
     For more details on this class, see [sklearn.svm.NuSVC]
@@ -217,7 +229,9 @@ class NuSVC(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -307,11 +321,6 @@ class NuSVC(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -339,7 +348,9 @@ class NuSVC(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -610,6 +621,22 @@ class NuSVC(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -625,8 +652,8 @@ class NuSVC(BaseTransformer):

         return output_df

-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -639,13 +666,21 @@ class NuSVC(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/svm/nu_svr.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.svm".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class NuSVR(BaseTransformer):
     r"""Nu Support Vector Regression
     For more details on this class, see [sklearn.svm.NuSVR]
@@ -182,7 +194,9 @@ class NuSVR(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -268,11 +282,6 @@ class NuSVR(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -300,7 +309,9 @@ class NuSVR(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -571,6 +582,22 @@ class NuSVR(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -586,8 +613,8 @@ class NuSVR(BaseTransformer):

         return output_df

-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -600,13 +627,21 @@ class NuSVR(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/svm/svc.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.svm".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class SVC(BaseTransformer):
     r"""C-Support Vector Classification
     For more details on this class, see [sklearn.svm.SVC]
@@ -220,7 +232,9 @@ class SVC(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -310,11 +324,6 @@ class SVC(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -342,7 +351,9 @@ class SVC(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -613,6 +624,22 @@ class SVC(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -628,8 +655,8 @@ class SVC(BaseTransformer):

         return output_df

-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -642,13 +669,21 @@ class SVC(BaseTransformer):
        Returns:
            Predicted dataset.
        """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.