snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
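
The same comparison can be reproduced offline. A minimal sketch in Python, assuming both wheel files have already been downloaded into the working directory (for example via pip download snowflake-ml-python==&lt;version&gt; --no-deps); the filenames and the member path below are illustrative:

import difflib
import zipfile

# Assumed local filenames; a wheel is a plain zip archive.
OLD = "snowflake_ml_python-1.1.2-py3-none-any.whl"
NEW = "snowflake_ml_python-1.2.1-py3-none-any.whl"

def read_member(wheel_path: str, member: str) -> list:
    # Extract one source file from the wheel, split into lines for difflib.
    with zipfile.ZipFile(wheel_path) as zf:
        return zf.read(member).decode("utf-8").splitlines(keepends=True)

member = "snowflake/ml/modeling/naive_bayes/bernoulli_nb.py"
print("".join(difflib.unified_diff(
    read_member(OLD, member), read_member(NEW, member),
    fromfile=f"1.1.2/{member}", tofile=f"1.2.1/{member}",
)))
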
Files changed (215)
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
snowflake/ml/modeling/naive_bayes/bernoulli_nb.py  +51 -16

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.naive_bayes".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class BernoulliNB(BaseTransformer):
     r"""Naive Bayes classifier for multivariate Bernoulli models
     For more details on this class, see [sklearn.naive_bayes.BernoulliNB]
@@ -150,7 +162,9 @@ class BernoulliNB(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -230,11 +244,6 @@ class BernoulliNB(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -262,7 +271,9 @@ class BernoulliNB(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -533,6 +544,22 @@ class BernoulliNB(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -548,8 +575,8 @@ class BernoulliNB(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -562,13 +589,21 @@ class BernoulliNB(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels: npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/naive_bayes/categorical_nb.py  +51 -16

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.naive_bayes".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class CategoricalNB(BaseTransformer):
     r"""Naive Bayes classifier for categorical features
     For more details on this class, see [sklearn.naive_bayes.CategoricalNB]
@@ -156,7 +168,9 @@ class CategoricalNB(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -236,11 +250,6 @@ class CategoricalNB(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -268,7 +277,9 @@ class CategoricalNB(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -539,6 +550,22 @@ class CategoricalNB(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -554,8 +581,8 @@ class CategoricalNB(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -568,13 +595,21 @@ class CategoricalNB(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels: npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/naive_bayes/complement_nb.py  +51 -16

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.naive_bayes".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class ComplementNB(BaseTransformer):
     r"""The Complement Naive Bayes classifier described in Rennie et al
     For more details on this class, see [sklearn.naive_bayes.ComplementNB]
@@ -150,7 +162,9 @@ class ComplementNB(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -230,11 +244,6 @@ class ComplementNB(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -262,7 +271,9 @@ class ComplementNB(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -533,6 +544,22 @@ class ComplementNB(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -548,8 +575,8 @@ class ComplementNB(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -562,13 +589,21 @@ class ComplementNB(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels: npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/naive_bayes/gaussian_nb.py  +51 -16

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.naive_bayes".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class GaussianNB(BaseTransformer):
     r"""Gaussian Naive Bayes (GaussianNB)
     For more details on this class, see [sklearn.naive_bayes.GaussianNB]
@@ -134,7 +146,9 @@ class GaussianNB(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -211,11 +225,6 @@ class GaussianNB(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -243,7 +252,9 @@ class GaussianNB(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -514,6 +525,22 @@ class GaussianNB(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -529,8 +556,8 @@ class GaussianNB(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -543,13 +570,21 @@ class GaussianNB(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels: npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/naive_bayes/multinomial_nb.py  +51 -16

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.naive_bayes".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class MultinomialNB(BaseTransformer):
     r"""Naive Bayes classifier for multinomial models
     For more details on this class, see [sklearn.naive_bayes.MultinomialNB]
@@ -145,7 +157,9 @@ class MultinomialNB(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -224,11 +238,6 @@ class MultinomialNB(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -256,7 +265,9 @@ class MultinomialNB(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -527,6 +538,22 @@ class MultinomialNB(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -542,8 +569,8 @@ class MultinomialNB(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -556,13 +583,21 @@ class MultinomialNB(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels: npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.