snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package as they appear in their public registry. It is provided for informational purposes only.
Files changed (215)
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
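Beyond the autogenerated modeling wrappers diffed below, the registry surface is the other headline change in this release: snowflake/ml/registry/registry.py was largely rewritten (item 205) on top of the new ModelManager (item 203), with tag support (item 17) and an expanded model/version client API (items 10-11). As rough orientation only, here is a hedged sketch assuming the documented 1.2.x Registry API; nothing below is taken from this diff, and the names are illustrative:

    # Hedged sketch: assumes the 1.2.x registry API (Registry, log_model, run).
    # `session` is an existing snowflake.snowpark.Session.
    from snowflake.ml.registry import Registry

    reg = Registry(session=session)
    mv = reg.log_model(
        model,                   # e.g. a fitted sklearn or snowflake.ml estimator
        model_name="MY_MODEL",
        version_name="V1",
    )
    mv.run(test_df, function_name="predict")   # run a model method on a Snowpark DataFrame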
snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.kernel_approximation".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class PolynomialCountSketch(BaseTransformer):
     r"""Polynomial kernel approximation via Tensor Sketch
     For more details on this class, see [sklearn.kernel_approximation.PolynomialCountSketch]
@@ -151,7 +163,9 @@ class PolynomialCountSketch(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -231,11 +245,6 @@ class PolynomialCountSketch(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package versions in the user workspace are supported in the Snowflake conda channel
-            # If a package isn't available in the conda channel, replace it with the closest available version
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -263,7 +272,9 @@ class PolynomialCountSketch(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -534,6 +545,22 @@ class PolynomialCountSketch(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -549,8 +576,8 @@ class PolynomialCountSketch(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -563,13 +590,21 @@ class PolynomialCountSketch(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels: npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc. functions.
snowflake/ml/modeling/kernel_approximation/rbf_sampler.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.kernel_approximation".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class RBFSampler(BaseTransformer):
     r"""Approximate a RBF kernel feature map using random Fourier features
     For more details on this class, see [sklearn.kernel_approximation.RBFSampler]
@@ -140,7 +152,9 @@ class RBFSampler(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -218,11 +232,6 @@ class RBFSampler(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package versions in the user workspace are supported in the Snowflake conda channel
-            # If a package isn't available in the conda channel, replace it with the closest available version
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -250,7 +259,9 @@ class RBFSampler(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -521,6 +532,22 @@ class RBFSampler(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -536,8 +563,8 @@ class RBFSampler(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -550,13 +577,21 @@ class RBFSampler(BaseTransformer):
         Returns:
            Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels: npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc. functions.
snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.kernel_approximation".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class SkewedChi2Sampler(BaseTransformer):
     r"""Approximate feature map for "skewed chi-squared" kernel
     For more details on this class, see [sklearn.kernel_approximation.SkewedChi2Sampler]
@@ -138,7 +150,9 @@ class SkewedChi2Sampler(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -216,11 +230,6 @@ class SkewedChi2Sampler(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package versions in the user workspace are supported in the Snowflake conda channel
-            # If a package isn't available in the conda channel, replace it with the closest available version
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -248,7 +257,9 @@ class SkewedChi2Sampler(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -519,6 +530,22 @@ class SkewedChi2Sampler(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -534,8 +561,8 @@ class SkewedChi2Sampler(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -548,13 +575,21 @@ class SkewedChi2Sampler(BaseTransformer):
         Returns:
            Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels: npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc. functions.
snowflake/ml/modeling/kernel_ridge/kernel_ridge.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.kernel_ridge".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class KernelRidge(BaseTransformer):
     r"""Kernel ridge regression
     For more details on this class, see [sklearn.kernel_ridge.KernelRidge]
@@ -171,7 +183,9 @@ class KernelRidge(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -252,11 +266,6 @@ class KernelRidge(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package versions in the user workspace are supported in the Snowflake conda channel
-            # If a package isn't available in the conda channel, replace it with the closest available version
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -284,7 +293,9 @@ class KernelRidge(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -555,6 +566,22 @@ class KernelRidge(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -570,8 +597,8 @@ class KernelRidge(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -584,13 +611,21 @@ class KernelRidge(BaseTransformer):
         Returns:
            Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels: npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc. functions.
snowflake/ml/modeling/lightgbm/lgbm_classifier.py

@@ -53,6 +53,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "lightgbm".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class LGBMClassifier(BaseTransformer):
     r"""LightGBM classifier
     For more details on this class, see [lightgbm.LGBMClassifier]
@@ -144,7 +156,9 @@ class LGBMClassifier(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'lightgbm=={lightgbm.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -240,11 +254,6 @@ class LGBMClassifier(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package versions in the user workspace are supported in the Snowflake conda channel
-            # If a package isn't available in the conda channel, replace it with the closest available version
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -272,7 +281,9 @@ class LGBMClassifier(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -543,6 +554,22 @@ class LGBMClassifier(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -558,8 +585,8 @@ class LGBMClassifier(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -572,13 +599,21 @@ class LGBMClassifier(BaseTransformer):
         Returns:
            Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels: npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc. functions.