snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (218)
  1. snowflake/ml/_internal/env_utils.py +72 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
  7. snowflake/ml/_internal/telemetry.py +1 -0
  8. snowflake/ml/_internal/utils/identifier.py +1 -1
  9. snowflake/ml/_internal/utils/sql_identifier.py +14 -1
  10. snowflake/ml/dataset/__init__.py +11 -0
  11. snowflake/ml/dataset/dataset.py +455 -129
  12. snowflake/ml/dataset/dataset_factory.py +53 -0
  13. snowflake/ml/dataset/dataset_metadata.py +103 -0
  14. snowflake/ml/dataset/dataset_reader.py +199 -0
  15. snowflake/ml/feature_store/__init__.py +6 -0
  16. snowflake/ml/feature_store/access_manager.py +279 -0
  17. snowflake/ml/feature_store/feature_store.py +544 -358
  18. snowflake/ml/feature_store/feature_view.py +55 -16
  19. snowflake/ml/fileset/embedded_stage_fs.py +149 -0
  20. snowflake/ml/fileset/sfcfs.py +0 -4
  21. snowflake/ml/fileset/snowfs.py +160 -0
  22. snowflake/ml/fileset/stage_fs.py +25 -10
  23. snowflake/ml/model/__init__.py +2 -2
  24. snowflake/ml/model/_api.py +16 -1
  25. snowflake/ml/model/_client/model/model_impl.py +65 -31
  26. snowflake/ml/model/_client/model/model_version_impl.py +159 -2
  27. snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
  28. snowflake/ml/model/_client/ops/model_ops.py +268 -83
  29. snowflake/ml/model/_client/sql/_base.py +34 -0
  30. snowflake/ml/model/_client/sql/model.py +42 -47
  31. snowflake/ml/model/_client/sql/model_version.py +164 -39
  32. snowflake/ml/model/_client/sql/stage.py +6 -32
  33. snowflake/ml/model/_client/sql/tag.py +32 -56
  34. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  35. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  36. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  37. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  38. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  39. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  40. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  41. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  42. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  43. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
  44. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  45. snowflake/ml/model/_packager/model_packager.py +0 -3
  46. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  47. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  48. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  49. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  50. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
  51. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  52. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
  53. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  54. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  55. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  56. snowflake/ml/modeling/cluster/birch.py +53 -52
  57. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  58. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  59. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  60. snowflake/ml/modeling/cluster/k_means.py +53 -52
  61. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  62. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  63. snowflake/ml/modeling/cluster/optics.py +51 -52
  64. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  65. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  66. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  67. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  68. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  69. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  70. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  71. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  72. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  73. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  74. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  75. snowflake/ml/modeling/covariance/oas.py +51 -52
  76. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  77. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  78. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  79. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  80. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  81. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  82. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  83. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  84. snowflake/ml/modeling/decomposition/pca.py +53 -52
  85. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  86. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  87. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  88. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  89. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  92. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  93. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  94. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  95. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  96. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  97. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  98. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  99. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  100. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  101. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  102. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  103. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  104. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  105. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  106. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  107. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  108. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  109. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  110. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  111. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  112. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  113. snowflake/ml/modeling/framework/base.py +64 -36
  114. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  115. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  116. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  117. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  118. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  119. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  120. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  121. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  122. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  123. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  124. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  125. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  126. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  127. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  128. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  129. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  130. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  131. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  132. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  133. snowflake/ml/modeling/linear_model/lars.py +51 -52
  134. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  135. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  136. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  137. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  138. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  139. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  140. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  141. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  142. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  143. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  144. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  146. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  147. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  148. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  149. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  151. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  152. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  153. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  154. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  155. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  156. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  157. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  158. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  159. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  160. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  161. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  162. snowflake/ml/modeling/manifold/isomap.py +53 -52
  163. snowflake/ml/modeling/manifold/mds.py +53 -52
  164. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  165. snowflake/ml/modeling/manifold/tsne.py +53 -52
  166. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  167. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  168. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  169. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  170. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  171. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  172. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  173. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  174. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  175. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  176. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  177. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  178. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  179. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  180. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  181. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  182. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  183. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  184. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  185. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  186. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  187. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  188. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  189. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  190. snowflake/ml/modeling/pipeline/pipeline.py +538 -36
  191. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  192. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  193. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  194. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  195. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  196. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  197. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  198. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  199. snowflake/ml/modeling/svm/svc.py +51 -52
  200. snowflake/ml/modeling/svm/svr.py +51 -52
  201. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  202. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  203. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  204. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  205. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  206. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  207. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  208. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  209. snowflake/ml/registry/_manager/model_manager.py +36 -7
  210. snowflake/ml/registry/model_registry.py +3 -149
  211. snowflake/ml/version.py +1 -1
  212. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
  213. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
  214. snowflake/ml/registry/_artifact_manager.py +0 -156
  215. snowflake/ml/registry/artifact.py +0 -46
  216. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
  217. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
  218. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/naive_bayes/complement_nb.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.naive_bayes".replace("sk
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class ComplementNB(BaseTransformer):
     r"""The Complement Naive Bayes classifier described in Rennie et al
     For more details on this class, see [sklearn.naive_bayes.ComplementNB]
@@ -283,20 +277,17 @@ class ComplementNB(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe and
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -314,9 +305,7 @@ class ComplementNB(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-        # Validate that key package version in user workspace are supported in snowflake conda channel
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -364,7 +353,8 @@ class ComplementNB(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -447,10 +437,8 @@ class ComplementNB(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -517,16 +505,40 @@ class ComplementNB(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
-    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
-        """
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.fit(dataset)
-        assert self._sklearn_object is not None
-        return self._sklearn_object.embedding_
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -619,10 +631,8 @@ class ComplementNB(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -689,10 +699,8 @@ class ComplementNB(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -754,10 +762,8 @@ class ComplementNB(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -823,10 +829,8 @@ class ComplementNB(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -890,17 +894,15 @@ class ComplementNB(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._deps,
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -965,11 +967,8 @@ class ComplementNB(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
snowflake/ml/modeling/naive_bayes/gaussian_nb.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.naive_bayes".replace("sk
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class GaussianNB(BaseTransformer):
     r"""Gaussian Naive Bayes (GaussianNB)
     For more details on this class, see [sklearn.naive_bayes.GaussianNB]
@@ -264,20 +258,17 @@ class GaussianNB(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe and
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
            dataset: snowpark dataframe
            inference_method: the inference method such as predict, score...
-
+
        Raises:
            SnowflakeMLException: If the estimator is not fitted, raise error
            SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
        """
        if not self._is_fitted:
            raise exceptions.SnowflakeMLException(
@@ -295,9 +286,7 @@ class GaussianNB(BaseTransformer):
                    "Session must not specified for snowpark dataset."
                ),
            )
-        # Validate that key package version in user workspace are supported in snowflake conda channel
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -345,7 +334,8 @@ class GaussianNB(BaseTransformer):
 
                expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(
                dataset._session, Session
            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -428,10 +418,8 @@ class GaussianNB(BaseTransformer):
            if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
            transform_kwargs = dict(
@@ -498,16 +486,40 @@ class GaussianNB(BaseTransformer):
        self._is_fitted = True
        return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
-    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
-        """
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
        Returns:
            Transformed dataset.
        """
-        self.fit(dataset)
-        assert self._sklearn_object is not None
-        return self._sklearn_object.embedding_
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -600,10 +612,8 @@ class GaussianNB(BaseTransformer):
        expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
        if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(
                dataset._session, Session
            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -670,10 +680,8 @@ class GaussianNB(BaseTransformer):
        transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
        if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(
                dataset._session, Session
            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -735,10 +743,8 @@ class GaussianNB(BaseTransformer):
        expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
        if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(
                dataset._session, Session
            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -804,10 +810,8 @@ class GaussianNB(BaseTransformer):
        expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
        if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
            transform_kwargs = dict(
                session=dataset._session,
@@ -871,17 +875,15 @@ class GaussianNB(BaseTransformer):
        transform_kwargs: ScoreKwargsTypedDict = dict()
 
        if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
            selected_cols = self._get_active_columns()
            if len(selected_cols) > 0:
                dataset = dataset.select(selected_cols)
            assert isinstance(dataset._session, Session)  # keep mypy happy
            transform_kwargs = dict(
                session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._deps,
+                dependencies=self._deps,
                score_sproc_imports=['sklearn'],
            )
        elif isinstance(dataset, pd.DataFrame):
@@ -946,11 +948,8 @@ class GaussianNB(BaseTransformer):
 
        if isinstance(dataset, DataFrame):
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
            transform_kwargs = dict(
                session = dataset._session,
snowflake/ml/modeling/naive_bayes/multinomial_nb.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.naive_bayes".replace("sk
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class MultinomialNB(BaseTransformer):
     r"""Naive Bayes classifier for multinomial models
     For more details on this class, see [sklearn.naive_bayes.MultinomialNB]
@@ -277,20 +271,17 @@ class MultinomialNB(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe and
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
            dataset: snowpark dataframe
            inference_method: the inference method such as predict, score...
-
+
        Raises:
            SnowflakeMLException: If the estimator is not fitted, raise error
            SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
        """
        if not self._is_fitted:
            raise exceptions.SnowflakeMLException(
@@ -308,9 +299,7 @@ class MultinomialNB(BaseTransformer):
                    "Session must not specified for snowpark dataset."
                ),
            )
-        # Validate that key package version in user workspace are supported in snowflake conda channel
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -358,7 +347,8 @@ class MultinomialNB(BaseTransformer):
 
                expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(
                dataset._session, Session
            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -441,10 +431,8 @@ class MultinomialNB(BaseTransformer):
            if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
            transform_kwargs = dict(
@@ -511,16 +499,40 @@ class MultinomialNB(BaseTransformer):
        self._is_fitted = True
        return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
-    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
-        """
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
        Returns:
            Transformed dataset.
        """
-        self.fit(dataset)
-        assert self._sklearn_object is not None
-        return self._sklearn_object.embedding_
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -613,10 +625,8 @@ class MultinomialNB(BaseTransformer):
        expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
        if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(
                dataset._session, Session
            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -683,10 +693,8 @@ class MultinomialNB(BaseTransformer):
        transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
        if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(
                dataset._session, Session
            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -748,10 +756,8 @@ class MultinomialNB(BaseTransformer):
        expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
        if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(
                dataset._session, Session
            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -817,10 +823,8 @@ class MultinomialNB(BaseTransformer):
        expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
        if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
            transform_kwargs = dict(
                session=dataset._session,
@@ -884,17 +888,15 @@ class MultinomialNB(BaseTransformer):
        transform_kwargs: ScoreKwargsTypedDict = dict()
 
        if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
            selected_cols = self._get_active_columns()
            if len(selected_cols) > 0:
                dataset = dataset.select(selected_cols)
            assert isinstance(dataset._session, Session)  # keep mypy happy
            transform_kwargs = dict(
                session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._deps,
+                dependencies=self._deps,
                score_sproc_imports=['sklearn'],
            )
        elif isinstance(dataset, pd.DataFrame):
@@ -959,11 +961,8 @@ class MultinomialNB(BaseTransformer):
 
        if isinstance(dataset, DataFrame):
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
            assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
            transform_kwargs = dict(
                session = dataset._session,
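
For orientation, the three naive Bayes wrappers above receive the same template change: the snowpark validation helper no longer returns the dependency list (callers now set self._deps from self._get_dependencies()), and a fit_transform override is generated, though its docstring marks it as unsupported for these classifier classes. Below is a minimal usage sketch of one of the touched wrappers; it is not part of the diff, and the connection parameters, table, and column names are hypothetical. The sketch sticks to fit/predict/score, which are the code paths the refactor touches.

# Illustrative sketch only: hypothetical connection parameters, table, and column names.
from snowflake.snowpark import Session
from snowflake.ml.modeling.naive_bayes import ComplementNB

# Placeholder connection details; fill in real values for your account.
connection_parameters = {
    "account": "...",
    "user": "...",
    "password": "...",
    "warehouse": "...",
    "database": "...",
    "schema": "...",
}
session = Session.builder.configs(connection_parameters).create()

# Hypothetical training table with non-negative count features and a label column.
train_df = session.table("MY_DB.MY_SCHEMA.TRAIN_FEATURES")

clf = ComplementNB(
    input_cols=["COUNT_0", "COUNT_1", "COUNT_2"],  # hypothetical feature columns
    label_cols=["LABEL"],
    output_cols=["PREDICTION"],
)
clf.fit(train_df)                    # fits the wrapped sklearn estimator on the Snowpark DataFrame
predictions = clf.predict(train_df)  # returns a Snowpark DataFrame with a PREDICTION column
accuracy = clf.score(train_df)       # exercises the refactored dependency handling shown above
predictions.show()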