snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (234)
  1. snowflake/ml/_internal/env_utils.py +77 -32
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/_internal/utils/identifier.py +3 -1
  8. snowflake/ml/_internal/utils/sql_identifier.py +2 -6
  9. snowflake/ml/dataset/__init__.py +10 -0
  10. snowflake/ml/dataset/dataset.py +454 -129
  11. snowflake/ml/dataset/dataset_factory.py +53 -0
  12. snowflake/ml/dataset/dataset_metadata.py +103 -0
  13. snowflake/ml/dataset/dataset_reader.py +202 -0
  14. snowflake/ml/feature_store/feature_store.py +531 -332
  15. snowflake/ml/feature_store/feature_view.py +40 -23
  16. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  17. snowflake/ml/fileset/sfcfs.py +56 -54
  18. snowflake/ml/fileset/snowfs.py +159 -0
  19. snowflake/ml/fileset/stage_fs.py +49 -17
  20. snowflake/ml/model/__init__.py +2 -2
  21. snowflake/ml/model/_api.py +16 -1
  22. snowflake/ml/model/_client/model/model_impl.py +27 -0
  23. snowflake/ml/model/_client/model/model_version_impl.py +137 -50
  24. snowflake/ml/model/_client/ops/model_ops.py +159 -40
  25. snowflake/ml/model/_client/sql/model.py +25 -2
  26. snowflake/ml/model/_client/sql/model_version.py +131 -2
  27. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  28. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  29. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  30. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  31. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  32. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  33. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
  34. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
  35. snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
  36. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  37. snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
  38. snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
  39. snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
  40. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  41. snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
  42. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
  43. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
  44. snowflake/ml/model/_packager/model_packager.py +2 -5
  45. snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
  46. snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
  47. snowflake/ml/model/type_hints.py +21 -2
  48. snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
  49. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
  50. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  51. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  52. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  53. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  54. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
  55. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
  56. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  57. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
  58. snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
  59. snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
  60. snowflake/ml/modeling/cluster/birch.py +248 -175
  61. snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
  62. snowflake/ml/modeling/cluster/dbscan.py +246 -175
  63. snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
  64. snowflake/ml/modeling/cluster/k_means.py +248 -175
  65. snowflake/ml/modeling/cluster/mean_shift.py +246 -175
  66. snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
  67. snowflake/ml/modeling/cluster/optics.py +246 -175
  68. snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
  69. snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
  70. snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
  71. snowflake/ml/modeling/compose/column_transformer.py +248 -175
  72. snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
  73. snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
  74. snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
  75. snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
  76. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
  77. snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
  78. snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
  79. snowflake/ml/modeling/covariance/oas.py +246 -175
  80. snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
  81. snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
  82. snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
  83. snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
  84. snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
  85. snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
  86. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
  87. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
  88. snowflake/ml/modeling/decomposition/pca.py +248 -175
  89. snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
  90. snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
  91. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
  92. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
  93. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
  94. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
  95. snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
  96. snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
  97. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
  98. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
  99. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
  100. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
  101. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
  102. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
  103. snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
  104. snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
  105. snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
  106. snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
  107. snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
  108. snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
  109. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
  110. snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
  111. snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
  112. snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
  113. snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
  114. snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
  115. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
  116. snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
  117. snowflake/ml/modeling/framework/_utils.py +8 -1
  118. snowflake/ml/modeling/framework/base.py +72 -37
  119. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
  120. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
  121. snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
  122. snowflake/ml/modeling/impute/knn_imputer.py +248 -175
  123. snowflake/ml/modeling/impute/missing_indicator.py +248 -175
  124. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
  125. snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
  126. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
  127. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
  128. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
  129. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
  130. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
  131. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
  132. snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
  133. snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
  134. snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
  135. snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
  136. snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
  137. snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
  138. snowflake/ml/modeling/linear_model/lars.py +246 -175
  139. snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
  140. snowflake/ml/modeling/linear_model/lasso.py +246 -175
  141. snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
  142. snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
  143. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
  144. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
  145. snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
  146. snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
  147. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
  148. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
  149. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
  150. snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
  151. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
  152. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
  153. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
  154. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
  155. snowflake/ml/modeling/linear_model/perceptron.py +246 -175
  156. snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
  157. snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
  158. snowflake/ml/modeling/linear_model/ridge.py +246 -175
  159. snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
  160. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
  161. snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
  162. snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
  163. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
  164. snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
  165. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
  166. snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
  167. snowflake/ml/modeling/manifold/isomap.py +248 -175
  168. snowflake/ml/modeling/manifold/mds.py +248 -175
  169. snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
  170. snowflake/ml/modeling/manifold/tsne.py +248 -175
  171. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
  172. snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
  173. snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
  174. snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
  175. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
  176. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
  177. snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
  178. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
  179. snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
  180. snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
  181. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
  182. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
  183. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
  184. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
  185. snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
  186. snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
  187. snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
  188. snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
  189. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
  190. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
  191. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
  192. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
  193. snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
  194. snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
  195. snowflake/ml/modeling/pipeline/pipeline.py +517 -35
  196. snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
  197. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
  198. snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
  199. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
  200. snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
  201. snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
  202. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
  203. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
  204. snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
  205. snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
  206. snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
  207. snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
  208. snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
  209. snowflake/ml/modeling/svm/linear_svc.py +246 -175
  210. snowflake/ml/modeling/svm/linear_svr.py +246 -175
  211. snowflake/ml/modeling/svm/nu_svc.py +246 -175
  212. snowflake/ml/modeling/svm/nu_svr.py +246 -175
  213. snowflake/ml/modeling/svm/svc.py +246 -175
  214. snowflake/ml/modeling/svm/svr.py +246 -175
  215. snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
  216. snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
  217. snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
  218. snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
  219. snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
  220. snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
  221. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
  222. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
  223. snowflake/ml/registry/model_registry.py +3 -149
  224. snowflake/ml/registry/registry.py +1 -1
  225. snowflake/ml/version.py +1 -1
  226. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
  227. snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
  228. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
  229. snowflake/ml/registry/_artifact_manager.py +0 -156
  230. snowflake/ml/registry/artifact.py +0 -46
  231. snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
  232. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  233. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  234. {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
33
33
  BatchInferenceKwargsTypedDict,
34
34
  ScoreKwargsTypedDict
35
35
  )
36
+ from snowflake.ml.model._signatures import utils as model_signature_utils
37
+ from snowflake.ml.model.model_signature import (
38
+ BaseFeatureSpec,
39
+ DataType,
40
+ FeatureSpec,
41
+ ModelSignature,
42
+ _infer_signature,
43
+ _rename_signature_with_snowflake_identifiers,
44
+ )
36
45
 
37
46
  from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
38
47
 
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
43
52
  validate_sklearn_args,
44
53
  )
45
54
 
46
- from snowflake.ml.model.model_signature import (
47
- DataType,
48
- FeatureSpec,
49
- ModelSignature,
50
- _infer_signature,
51
- _rename_signature_with_snowflake_identifiers,
52
- BaseFeatureSpec,
53
- )
54
- from snowflake.ml.model._signatures import utils as model_signature_utils
55
-
56
55
  _PROJECT = "ModelDevelopment"
57
56
  # Derive subproject from module name by removing "sklearn"
58
57
  # and converting module name from underscore to CamelCase
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.naive_bayes".replace("sk
61
60
 
62
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
63
62
 
64
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
65
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
66
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
67
- return check
68
-
69
-
70
63
  class MultinomialNB(BaseTransformer):
71
64
  r"""Naive Bayes classifier for multinomial models
72
65
  For more details on this class, see [sklearn.naive_bayes.MultinomialNB]
@@ -216,12 +209,7 @@ class MultinomialNB(BaseTransformer):
216
209
  )
217
210
  return selected_cols
218
211
 
219
- @telemetry.send_api_usage_telemetry(
220
- project=_PROJECT,
221
- subproject=_SUBPROJECT,
222
- custom_tags=dict([("autogen", True)]),
223
- )
224
- def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "MultinomialNB":
212
+ def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "MultinomialNB":
225
213
  """Fit Naive Bayes classifier according to X, y
226
214
  For more details on this function, see [sklearn.naive_bayes.MultinomialNB.fit]
227
215
  (https://scikit-learn.org/stable/modules/generated/sklearn.naive_bayes.MultinomialNB.html#sklearn.naive_bayes.MultinomialNB.fit)
@@ -248,12 +236,14 @@ class MultinomialNB(BaseTransformer):
248
236
 
249
237
  self._snowpark_cols = dataset.select(self.input_cols).columns
250
238
 
251
- # If we are already in a stored procedure, no need to kick off another one.
239
+ # If we are already in a stored procedure, no need to kick off another one.
252
240
  if SNOWML_SPROC_ENV in os.environ:
253
241
  statement_params = telemetry.get_function_usage_statement_params(
254
242
  project=_PROJECT,
255
243
  subproject=_SUBPROJECT,
256
- function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), MultinomialNB.__class__.__name__),
244
+ function_name=telemetry.get_statement_params_full_func_name(
245
+ inspect.currentframe(), MultinomialNB.__class__.__name__
246
+ ),
257
247
  api_calls=[Session.call],
258
248
  custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
259
249
  )
@@ -274,27 +264,24 @@ class MultinomialNB(BaseTransformer):
274
264
  )
275
265
  self._sklearn_object = model_trainer.train()
276
266
  self._is_fitted = True
277
- self._get_model_signatures(dataset)
267
+ self._generate_model_signatures(dataset)
278
268
  return self
279
269
 
280
270
  def _batch_inference_validate_snowpark(
281
271
  self,
282
272
  dataset: DataFrame,
283
273
  inference_method: str,
284
- ) -> List[str]:
285
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
286
- return the available package that exists in the snowflake anaconda channel
274
+ ) -> None:
275
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
287
276
 
288
277
  Args:
289
278
  dataset: snowpark dataframe
290
279
  inference_method: the inference method such as predict, score...
291
-
280
+
292
281
  Raises:
293
282
  SnowflakeMLException: If the estimator is not fitted, raise error
294
283
  SnowflakeMLException: If the session is None, raise error
295
284
 
296
- Returns:
297
- A list of available package that exists in the snowflake anaconda channel
298
285
  """
299
286
  if not self._is_fitted:
300
287
  raise exceptions.SnowflakeMLException(
@@ -312,9 +299,7 @@ class MultinomialNB(BaseTransformer):
312
299
  "Session must not specified for snowpark dataset."
313
300
  ),
314
301
  )
315
- # Validate that key package version in user workspace are supported in snowflake conda channel
316
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
317
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
302
+
318
303
 
319
304
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
320
305
  @telemetry.send_api_usage_telemetry(
@@ -350,7 +335,9 @@ class MultinomialNB(BaseTransformer):
350
335
  # when it is classifier, infer the datatype from label columns
351
336
  if expected_type_inferred == "" and 'predict' in self.model_signatures:
352
337
  # Batch inference takes a single expected output column type. Use the first columns type for now.
353
- label_cols_signatures = [row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols]
338
+ label_cols_signatures = [
339
+ row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
340
+ ]
354
341
  if len(label_cols_signatures) == 0:
355
342
  error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
356
343
  raise exceptions.SnowflakeMLException(
@@ -358,25 +345,23 @@ class MultinomialNB(BaseTransformer):
358
345
  original_exception=ValueError(error_str),
359
346
  )
360
347
 
361
- expected_type_inferred = convert_sp_to_sf_type(
362
- label_cols_signatures[0].as_snowpark_type()
363
- )
348
+ expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
364
349
 
365
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
366
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
350
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
351
+ self._deps = self._get_dependencies()
352
+ assert isinstance(
353
+ dataset._session, Session
354
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
367
355
 
368
356
  transform_kwargs = dict(
369
- session = dataset._session,
370
- dependencies = self._deps,
371
- drop_input_cols = self._drop_input_cols,
372
- expected_output_cols_type = expected_type_inferred,
357
+ session=dataset._session,
358
+ dependencies=self._deps,
359
+ drop_input_cols=self._drop_input_cols,
360
+ expected_output_cols_type=expected_type_inferred,
373
361
  )
374
362
 
375
363
  elif isinstance(dataset, pd.DataFrame):
376
- transform_kwargs = dict(
377
- snowpark_input_cols = self._snowpark_cols,
378
- drop_input_cols = self._drop_input_cols
379
- )
364
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
380
365
 
381
366
  transform_handlers = ModelTransformerBuilder.build(
382
367
  dataset=dataset,
@@ -416,7 +401,7 @@ class MultinomialNB(BaseTransformer):
416
401
  Transformed dataset.
417
402
  """
418
403
  super()._check_dataset_type(dataset)
419
- inference_method="transform"
404
+ inference_method = "transform"
420
405
 
421
406
  # This dictionary contains optional kwargs for batch inference. These kwargs
422
407
  # are specific to the type of dataset used.
@@ -446,24 +431,19 @@ class MultinomialNB(BaseTransformer):
446
431
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
447
432
  expected_dtype = convert_sp_to_sf_type(output_types[0])
448
433
 
449
- self._deps = self._batch_inference_validate_snowpark(
450
- dataset=dataset,
451
- inference_method=inference_method,
452
- )
434
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
435
+ self._deps = self._get_dependencies()
453
436
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
454
437
 
455
438
  transform_kwargs = dict(
456
- session = dataset._session,
457
- dependencies = self._deps,
458
- drop_input_cols = self._drop_input_cols,
459
- expected_output_cols_type = expected_dtype,
439
+ session=dataset._session,
440
+ dependencies=self._deps,
441
+ drop_input_cols=self._drop_input_cols,
442
+ expected_output_cols_type=expected_dtype,
460
443
  )
461
444
 
462
445
  elif isinstance(dataset, pd.DataFrame):
463
- transform_kwargs = dict(
464
- snowpark_input_cols = self._snowpark_cols,
465
- drop_input_cols = self._drop_input_cols
466
- )
446
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
467
447
 
468
448
  transform_handlers = ModelTransformerBuilder.build(
469
449
  dataset=dataset,
@@ -482,7 +462,11 @@ class MultinomialNB(BaseTransformer):
482
462
  return output_df
483
463
 
484
464
  @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
485
- def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_predict_",) -> Union[DataFrame, pd.DataFrame]:
465
+ def fit_predict(
466
+ self,
467
+ dataset: Union[DataFrame, pd.DataFrame],
468
+ output_cols_prefix: str = "fit_predict_",
469
+ ) -> Union[DataFrame, pd.DataFrame]:
486
470
  """ Method not supported for this class.
487
471
 
488
472
 
@@ -507,22 +491,104 @@ class MultinomialNB(BaseTransformer):
507
491
  )
508
492
  output_result, fitted_estimator = model_trainer.train_fit_predict(
509
493
  drop_input_cols=self._drop_input_cols,
510
- expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
494
+ expected_output_cols_list=(
495
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
496
+ ),
511
497
  )
512
498
  self._sklearn_object = fitted_estimator
513
499
  self._is_fitted = True
514
500
  return output_result
515
501
 
502
+
503
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
504
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
505
+ """ Method not supported for this class.
506
+
516
507
 
517
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
518
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
519
- """
508
+ Raises:
509
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
510
+
511
+ Args:
512
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
513
+ Snowpark or Pandas DataFrame.
514
+ output_cols_prefix: Prefix for the response columns
520
515
  Returns:
521
516
  Transformed dataset.
522
517
  """
523
- self.fit(dataset)
524
- assert self._sklearn_object is not None
525
- return self._sklearn_object.embedding_
518
+ self._infer_input_output_cols(dataset)
519
+ super()._check_dataset_type(dataset)
520
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
521
+ estimator=self._sklearn_object,
522
+ dataset=dataset,
523
+ input_cols=self.input_cols,
524
+ label_cols=self.label_cols,
525
+ sample_weight_col=self.sample_weight_col,
526
+ autogenerated=self._autogenerated,
527
+ subproject=_SUBPROJECT,
528
+ )
529
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
530
+ drop_input_cols=self._drop_input_cols,
531
+ expected_output_cols_list=self.output_cols,
532
+ )
533
+ self._sklearn_object = fitted_estimator
534
+ self._is_fitted = True
535
+ return output_result
536
+
537
+
538
+ def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
539
+ """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
540
+ Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
541
+ """
542
+ output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
543
+ # The following condition is introduced for kneighbors methods, and not used in other methods
544
+ if output_cols:
545
+ output_cols = [
546
+ identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
547
+ for c in output_cols
548
+ ]
549
+ elif getattr(self._sklearn_object, "classes_", None) is None:
550
+ output_cols = [output_cols_prefix]
551
+ elif self._sklearn_object is not None:
552
+ classes = self._sklearn_object.classes_
553
+ if isinstance(classes, numpy.ndarray):
554
+ output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
555
+ elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
556
+ # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
557
+ output_cols = []
558
+ for i, cl in enumerate(classes):
559
+ # For binary classification, there is only one output column for each class
560
+ # ndarray as the two classes are complementary.
561
+ if len(cl) == 2:
562
+ output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
563
+ else:
564
+ output_cols.extend([
565
+ f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
566
+ ])
567
+ else:
568
+ output_cols = []
569
+
570
+ # Make sure column names are valid snowflake identifiers.
571
+ assert output_cols is not None # Make MyPy happy
572
+ rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
573
+
574
+ return rv
575
+
576
+ def _align_expected_output_names(
577
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
578
+ ) -> List[str]:
579
+ # in case the inferred output column names dimension is different
580
+ # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
581
+ output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
582
+ output_df_columns = list(output_df_pd.columns)
583
+ output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
584
+ if self.sample_weight_col:
585
+ output_df_columns_set -= set(self.sample_weight_col)
586
+ # if the dimension of inferred output column names is correct; use it
587
+ if len(expected_output_cols_list) == len(output_df_columns_set):
588
+ return expected_output_cols_list
589
+ # otherwise, use the sklearn estimator's output
590
+ else:
591
+ return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
526
592
 
527
593
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
528
594
  @telemetry.send_api_usage_telemetry(
@@ -556,24 +622,26 @@ class MultinomialNB(BaseTransformer):
556
622
  # are specific to the type of dataset used.
557
623
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
558
624
 
625
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
626
+
559
627
  if isinstance(dataset, DataFrame):
560
- self._deps = self._batch_inference_validate_snowpark(
561
- dataset=dataset,
562
- inference_method=inference_method,
563
- )
564
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
628
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
629
+ self._deps = self._get_dependencies()
630
+ assert isinstance(
631
+ dataset._session, Session
632
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
565
633
  transform_kwargs = dict(
566
634
  session=dataset._session,
567
635
  dependencies=self._deps,
568
- drop_input_cols = self._drop_input_cols,
636
+ drop_input_cols=self._drop_input_cols,
569
637
  expected_output_cols_type="float",
570
638
  )
639
+ expected_output_cols = self._align_expected_output_names(
640
+ inference_method, dataset, expected_output_cols, output_cols_prefix
641
+ )
571
642
 
572
643
  elif isinstance(dataset, pd.DataFrame):
573
- transform_kwargs = dict(
574
- snowpark_input_cols = self._snowpark_cols,
575
- drop_input_cols = self._drop_input_cols
576
- )
644
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
577
645
 
578
646
  transform_handlers = ModelTransformerBuilder.build(
579
647
  dataset=dataset,
@@ -585,7 +653,7 @@ class MultinomialNB(BaseTransformer):
585
653
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
586
654
  inference_method=inference_method,
587
655
  input_cols=self.input_cols,
588
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
656
+ expected_output_cols=expected_output_cols,
589
657
  **transform_kwargs
590
658
  )
591
659
  return output_df
@@ -617,29 +685,30 @@ class MultinomialNB(BaseTransformer):
617
685
  Output dataset with log probability of the sample for each class in the model.
618
686
  """
619
687
  super()._check_dataset_type(dataset)
620
- inference_method="predict_log_proba"
688
+ inference_method = "predict_log_proba"
689
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
621
690
 
622
691
  # This dictionary contains optional kwargs for batch inference. These kwargs
623
692
  # are specific to the type of dataset used.
624
693
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
625
694
 
626
695
  if isinstance(dataset, DataFrame):
627
- self._deps = self._batch_inference_validate_snowpark(
628
- dataset=dataset,
629
- inference_method=inference_method,
630
- )
631
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
696
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
697
+ self._deps = self._get_dependencies()
698
+ assert isinstance(
699
+ dataset._session, Session
700
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
632
701
  transform_kwargs = dict(
633
702
  session=dataset._session,
634
703
  dependencies=self._deps,
635
- drop_input_cols = self._drop_input_cols,
704
+ drop_input_cols=self._drop_input_cols,
636
705
  expected_output_cols_type="float",
637
706
  )
707
+ expected_output_cols = self._align_expected_output_names(
708
+ inference_method, dataset, expected_output_cols, output_cols_prefix
709
+ )
638
710
  elif isinstance(dataset, pd.DataFrame):
639
- transform_kwargs = dict(
640
- snowpark_input_cols = self._snowpark_cols,
641
- drop_input_cols = self._drop_input_cols
642
- )
711
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
643
712
 
644
713
  transform_handlers = ModelTransformerBuilder.build(
645
714
  dataset=dataset,
@@ -652,7 +721,7 @@ class MultinomialNB(BaseTransformer):
652
721
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
653
722
  inference_method=inference_method,
654
723
  input_cols=self.input_cols,
655
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
724
+ expected_output_cols=expected_output_cols,
656
725
  **transform_kwargs
657
726
  )
658
727
  return output_df
@@ -678,30 +747,32 @@ class MultinomialNB(BaseTransformer):
678
747
  Output dataset with results of the decision function for the samples in input dataset.
679
748
  """
680
749
  super()._check_dataset_type(dataset)
681
- inference_method="decision_function"
750
+ inference_method = "decision_function"
682
751
 
683
752
  # This dictionary contains optional kwargs for batch inference. These kwargs
684
753
  # are specific to the type of dataset used.
685
754
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
686
755
 
756
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
757
+
687
758
  if isinstance(dataset, DataFrame):
688
- self._deps = self._batch_inference_validate_snowpark(
689
- dataset=dataset,
690
- inference_method=inference_method,
691
- )
692
- assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
759
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
760
+ self._deps = self._get_dependencies()
761
+ assert isinstance(
762
+ dataset._session, Session
763
+ ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
693
764
  transform_kwargs = dict(
694
765
  session=dataset._session,
695
766
  dependencies=self._deps,
696
- drop_input_cols = self._drop_input_cols,
767
+ drop_input_cols=self._drop_input_cols,
697
768
  expected_output_cols_type="float",
698
769
  )
770
+ expected_output_cols = self._align_expected_output_names(
771
+ inference_method, dataset, expected_output_cols, output_cols_prefix
772
+ )
699
773
 
700
774
  elif isinstance(dataset, pd.DataFrame):
701
- transform_kwargs = dict(
702
- snowpark_input_cols = self._snowpark_cols,
703
- drop_input_cols = self._drop_input_cols
704
- )
775
+ transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
705
776
 
706
777
  transform_handlers = ModelTransformerBuilder.build(
707
778
  dataset=dataset,
@@ -714,7 +785,7 @@ class MultinomialNB(BaseTransformer):
714
785
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
715
786
  inference_method=inference_method,
716
787
  input_cols=self.input_cols,
717
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
788
+ expected_output_cols=expected_output_cols,
718
789
  **transform_kwargs
719
790
  )
720
791
  return output_df
@@ -743,17 +814,17 @@ class MultinomialNB(BaseTransformer):
743
814
  Output dataset with probability of the sample for each class in the model.
744
815
  """
745
816
  super()._check_dataset_type(dataset)
746
- inference_method="score_samples"
817
+ inference_method = "score_samples"
747
818
 
748
819
  # This dictionary contains optional kwargs for batch inference. These kwargs
749
820
  # are specific to the type of dataset used.
750
821
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
751
822
 
823
+ expected_output_cols = self._get_output_column_names(output_cols_prefix)
824
+
752
825
  if isinstance(dataset, DataFrame):
753
- self._deps = self._batch_inference_validate_snowpark(
754
- dataset=dataset,
755
- inference_method=inference_method,
756
- )
826
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
827
+ self._deps = self._get_dependencies()
757
828
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
758
829
  transform_kwargs = dict(
759
830
  session=dataset._session,
@@ -761,6 +832,9 @@ class MultinomialNB(BaseTransformer):
761
832
  drop_input_cols = self._drop_input_cols,
762
833
  expected_output_cols_type="float",
763
834
  )
835
+ expected_output_cols = self._align_expected_output_names(
836
+ inference_method, dataset, expected_output_cols, output_cols_prefix
837
+ )
764
838
 
765
839
  elif isinstance(dataset, pd.DataFrame):
766
840
  transform_kwargs = dict(
@@ -779,7 +853,7 @@ class MultinomialNB(BaseTransformer):
779
853
  output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
780
854
  inference_method=inference_method,
781
855
  input_cols=self.input_cols,
782
- expected_output_cols=self._get_output_column_names(output_cols_prefix),
856
+ expected_output_cols=expected_output_cols,
783
857
  **transform_kwargs
784
858
  )
785
859
  return output_df
@@ -814,17 +888,15 @@ class MultinomialNB(BaseTransformer):
814
888
  transform_kwargs: ScoreKwargsTypedDict = dict()
815
889
 
816
890
  if isinstance(dataset, DataFrame):
817
- self._deps = self._batch_inference_validate_snowpark(
818
- dataset=dataset,
819
- inference_method="score",
820
- )
891
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
892
+ self._deps = self._get_dependencies()
821
893
  selected_cols = self._get_active_columns()
822
894
  if len(selected_cols) > 0:
823
895
  dataset = dataset.select(selected_cols)
824
896
  assert isinstance(dataset._session, Session) # keep mypy happy
825
897
  transform_kwargs = dict(
826
898
  session=dataset._session,
827
- dependencies=["snowflake-snowpark-python"] + self._deps,
899
+ dependencies=self._deps,
828
900
  score_sproc_imports=['sklearn'],
829
901
  )
830
902
  elif isinstance(dataset, pd.DataFrame):
@@ -889,11 +961,8 @@ class MultinomialNB(BaseTransformer):
889
961
 
890
962
  if isinstance(dataset, DataFrame):
891
963
 
892
- self._deps = self._batch_inference_validate_snowpark(
893
- dataset=dataset,
894
- inference_method=inference_method,
895
-
896
- )
964
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
965
+ self._deps = self._get_dependencies()
897
966
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
898
967
  transform_kwargs = dict(
899
968
  session = dataset._session,
@@ -926,50 +995,84 @@ class MultinomialNB(BaseTransformer):
926
995
  )
927
996
  return output_df
928
997
 
998
+
999
+
1000
+ def to_sklearn(self) -> Any:
1001
+ """Get sklearn.naive_bayes.MultinomialNB object.
1002
+ """
1003
+ if self._sklearn_object is None:
1004
+ self._sklearn_object = self._create_sklearn_object()
1005
+ return self._sklearn_object
1006
+
1007
+ def to_xgboost(self) -> Any:
1008
+ raise exceptions.SnowflakeMLException(
1009
+ error_code=error_codes.METHOD_NOT_ALLOWED,
1010
+ original_exception=AttributeError(
1011
+ modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1012
+ "to_xgboost()",
1013
+ "to_sklearn()"
1014
+ )
1015
+ ),
1016
+ )
1017
+
1018
+ def to_lightgbm(self) -> Any:
1019
+ raise exceptions.SnowflakeMLException(
1020
+ error_code=error_codes.METHOD_NOT_ALLOWED,
1021
+ original_exception=AttributeError(
1022
+ modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1023
+ "to_lightgbm()",
1024
+ "to_sklearn()"
1025
+ )
1026
+ ),
1027
+ )
1028
+
1029
+ def _get_dependencies(self) -> List[str]:
1030
+ return self._deps
1031
+
929
1032
 
930
- def _get_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
1033
+ def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
931
1034
  self._model_signature_dict = dict()
932
1035
 
933
1036
  PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
934
1037
 
935
- inputs = list(_infer_signature(dataset[self.input_cols], "input"))
1038
+ inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
936
1039
  outputs: List[BaseFeatureSpec] = []
937
1040
  if hasattr(self, "predict"):
938
1041
  # keep mypy happy
939
- assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
1042
+ assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
940
1043
  # For classifier, the type of predict is the same as the type of label
941
- if self._sklearn_object._estimator_type == 'classifier':
942
- # label columns is the desired type for output
1044
+ if self._sklearn_object._estimator_type == "classifier":
1045
+ # label columns is the desired type for output
943
1046
  outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
944
1047
  # rename the output columns
945
1048
  outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
946
- self._model_signature_dict["predict"] = ModelSignature(inputs,
947
- ([] if self._drop_input_cols else inputs)
948
- + outputs)
1049
+ self._model_signature_dict["predict"] = ModelSignature(
1050
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1051
+ )
949
1052
  # For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
950
1053
  # For outlier models, returns -1 for outliers and 1 for inliers.
951
- # Clusterer returns int64 cluster labels.
1054
+ # Clusterer returns int64 cluster labels.
952
1055
  elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
953
1056
  outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
954
- self._model_signature_dict["predict"] = ModelSignature(inputs,
955
- ([] if self._drop_input_cols else inputs)
956
- + outputs)
957
-
1057
+ self._model_signature_dict["predict"] = ModelSignature(
1058
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1059
+ )
1060
+
958
1061
  # For regressor, the type of predict is float64
959
- elif self._sklearn_object._estimator_type == 'regressor':
1062
+ elif self._sklearn_object._estimator_type == "regressor":
960
1063
  outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
961
- self._model_signature_dict["predict"] = ModelSignature(inputs,
962
- ([] if self._drop_input_cols else inputs)
963
- + outputs)
964
-
1064
+ self._model_signature_dict["predict"] = ModelSignature(
1065
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1066
+ )
1067
+
965
1068
  for prob_func in PROB_FUNCTIONS:
966
1069
  if hasattr(self, prob_func):
967
1070
  output_cols_prefix: str = f"{prob_func}_"
968
1071
  output_column_names = self._get_output_column_names(output_cols_prefix)
969
1072
  outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
970
- self._model_signature_dict[prob_func] = ModelSignature(inputs,
971
- ([] if self._drop_input_cols else inputs)
972
- + outputs)
1073
+ self._model_signature_dict[prob_func] = ModelSignature(
1074
+ inputs, ([] if self._drop_input_cols else inputs) + outputs
1075
+ )
973
1076
 
974
1077
  # Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
975
1078
  items = list(self._model_signature_dict.items())
@@ -982,10 +1085,10 @@ class MultinomialNB(BaseTransformer):
982
1085
  """Returns model signature of current class.
983
1086
 
984
1087
  Raises:
985
- exceptions.SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
1088
+ SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
986
1089
 
987
1090
  Returns:
988
- Dict[str, ModelSignature]: each method and its input output signature
1091
+ Dict with each method and its input output signature
989
1092
  """
990
1093
  if self._model_signature_dict is None:
991
1094
  raise exceptions.SnowflakeMLException(
@@ -993,35 +1096,3 @@ class MultinomialNB(BaseTransformer):
993
1096
  original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
994
1097
  )
995
1098
  return self._model_signature_dict
996
-
997
- def to_sklearn(self) -> Any:
998
- """Get sklearn.naive_bayes.MultinomialNB object.
999
- """
1000
- if self._sklearn_object is None:
1001
- self._sklearn_object = self._create_sklearn_object()
1002
- return self._sklearn_object
1003
-
1004
- def to_xgboost(self) -> Any:
1005
- raise exceptions.SnowflakeMLException(
1006
- error_code=error_codes.METHOD_NOT_ALLOWED,
1007
- original_exception=AttributeError(
1008
- modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1009
- "to_xgboost()",
1010
- "to_sklearn()"
1011
- )
1012
- ),
1013
- )
1014
-
1015
- def to_lightgbm(self) -> Any:
1016
- raise exceptions.SnowflakeMLException(
1017
- error_code=error_codes.METHOD_NOT_ALLOWED,
1018
- original_exception=AttributeError(
1019
- modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
1020
- "to_lightgbm()",
1021
- "to_sklearn()"
1022
- )
1023
- ),
1024
- )
1025
-
1026
- def _get_dependencies(self) -> List[str]:
1027
- return self._deps