snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (218) hide show
  1. snowflake/ml/_internal/env_utils.py +72 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
  7. snowflake/ml/_internal/telemetry.py +1 -0
  8. snowflake/ml/_internal/utils/identifier.py +1 -1
  9. snowflake/ml/_internal/utils/sql_identifier.py +14 -1
  10. snowflake/ml/dataset/__init__.py +11 -0
  11. snowflake/ml/dataset/dataset.py +455 -129
  12. snowflake/ml/dataset/dataset_factory.py +53 -0
  13. snowflake/ml/dataset/dataset_metadata.py +103 -0
  14. snowflake/ml/dataset/dataset_reader.py +199 -0
  15. snowflake/ml/feature_store/__init__.py +6 -0
  16. snowflake/ml/feature_store/access_manager.py +279 -0
  17. snowflake/ml/feature_store/feature_store.py +544 -358
  18. snowflake/ml/feature_store/feature_view.py +55 -16
  19. snowflake/ml/fileset/embedded_stage_fs.py +149 -0
  20. snowflake/ml/fileset/sfcfs.py +0 -4
  21. snowflake/ml/fileset/snowfs.py +160 -0
  22. snowflake/ml/fileset/stage_fs.py +25 -10
  23. snowflake/ml/model/__init__.py +2 -2
  24. snowflake/ml/model/_api.py +16 -1
  25. snowflake/ml/model/_client/model/model_impl.py +65 -31
  26. snowflake/ml/model/_client/model/model_version_impl.py +159 -2
  27. snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
  28. snowflake/ml/model/_client/ops/model_ops.py +268 -83
  29. snowflake/ml/model/_client/sql/_base.py +34 -0
  30. snowflake/ml/model/_client/sql/model.py +42 -47
  31. snowflake/ml/model/_client/sql/model_version.py +164 -39
  32. snowflake/ml/model/_client/sql/stage.py +6 -32
  33. snowflake/ml/model/_client/sql/tag.py +32 -56
  34. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  35. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  36. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  37. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  38. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  39. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  40. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  41. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  42. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  43. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
  44. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  45. snowflake/ml/model/_packager/model_packager.py +0 -3
  46. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  47. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  48. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  49. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  50. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
  51. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  52. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
  53. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  54. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  55. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  56. snowflake/ml/modeling/cluster/birch.py +53 -52
  57. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  58. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  59. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  60. snowflake/ml/modeling/cluster/k_means.py +53 -52
  61. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  62. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  63. snowflake/ml/modeling/cluster/optics.py +51 -52
  64. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  65. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  66. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  67. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  68. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  69. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  70. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  71. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  72. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  73. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  74. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  75. snowflake/ml/modeling/covariance/oas.py +51 -52
  76. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  77. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  78. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  79. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  80. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  81. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  82. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  83. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  84. snowflake/ml/modeling/decomposition/pca.py +53 -52
  85. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  86. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  87. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  88. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  89. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  92. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  93. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  94. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  95. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  96. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  97. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  98. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  99. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  100. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  101. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  102. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  103. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  104. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  105. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  106. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  107. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  108. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  109. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  110. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  111. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  112. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  113. snowflake/ml/modeling/framework/base.py +64 -36
  114. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  115. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  116. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  117. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  118. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  119. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  120. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  121. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  122. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  123. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  124. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  125. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  126. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  127. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  128. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  129. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  130. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  131. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  132. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  133. snowflake/ml/modeling/linear_model/lars.py +51 -52
  134. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  135. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  136. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  137. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  138. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  139. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  140. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  141. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  142. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  143. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  144. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  146. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  147. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  148. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  149. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  151. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  152. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  153. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  154. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  155. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  156. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  157. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  158. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  159. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  160. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  161. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  162. snowflake/ml/modeling/manifold/isomap.py +53 -52
  163. snowflake/ml/modeling/manifold/mds.py +53 -52
  164. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  165. snowflake/ml/modeling/manifold/tsne.py +53 -52
  166. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  167. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  168. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  169. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  170. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  171. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  172. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  173. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  174. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  175. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  176. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  177. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  178. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  179. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  180. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  181. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  182. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  183. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  184. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  185. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  186. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  187. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  188. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  189. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  190. snowflake/ml/modeling/pipeline/pipeline.py +538 -36
  191. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  192. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  193. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  194. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  195. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  196. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  197. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  198. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  199. snowflake/ml/modeling/svm/svc.py +51 -52
  200. snowflake/ml/modeling/svm/svr.py +51 -52
  201. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  202. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  203. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  204. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  205. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  206. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  207. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  208. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  209. snowflake/ml/registry/_manager/model_manager.py +36 -7
  210. snowflake/ml/registry/model_registry.py +3 -149
  211. snowflake/ml/version.py +1 -1
  212. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
  213. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
  214. snowflake/ml/registry/_artifact_manager.py +0 -156
  215. snowflake/ml/registry/artifact.py +0 -46
  216. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
  217. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
  218. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.multiclass".replace("skl
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class OutputCodeClassifier(BaseTransformer):
70
64
  r"""(Error-Correcting) Output-Code multiclass strategy
71
65
  For more details on this class, see [sklearn.multiclass.OutputCodeClassifier]
@@ -283,20 +277,17 @@ class OutputCodeClassifier(BaseTransformer):
283
277
  self,
284
278
  dataset: DataFrame,
285
279
  inference_method: str,
286
- ) -> List[str]:
287
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
288
- return the available package that exists in the snowflake anaconda channel
280
+ ) -> None:
281
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
289
282
 
290
283
  Args:
291
284
  dataset: snowpark dataframe
292
285
  inference_method: the inference method such as predict, score...
293
-
286
+
294
287
  Raises:
295
288
  SnowflakeMLException: If the estimator is not fitted, raise error
296
289
  SnowflakeMLException: If the session is None, raise error
297
290
 
298
- Returns:
299
- A list of available package that exists in the snowflake anaconda channel
300
291
  """
301
292
  if not self._is_fitted:
302
293
  raise exceptions.SnowflakeMLException(
@@ -314,9 +305,7 @@ class OutputCodeClassifier(BaseTransformer):
314
305
  "Session must not specified for snowpark dataset."
315
306
  ),
316
307
  )
317
- # Validate that key package version in user workspace are supported in snowflake conda channel
318
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
319
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
308
+
320
309
 
321
310
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
322
311
  @telemetry.send_api_usage_telemetry(
@@ -364,7 +353,8 @@ class OutputCodeClassifier(BaseTransformer):
364
353
 
365
354
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
366
355
 
367
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
356
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
357
+ self._deps = self._get_dependencies()
368
358
  assert isinstance(
369
359
  dataset._session, Session
370
360
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -447,10 +437,8 @@ class OutputCodeClassifier(BaseTransformer):
447
437
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
448
438
  expected_dtype = convert_sp_to_sf_type(output_types[0])
449
439
 
450
- self._deps = self._batch_inference_validate_snowpark(
451
- dataset=dataset,
452
- inference_method=inference_method,
453
- )
440
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
441
+ self._deps = self._get_dependencies()
454
442
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
455
443
 
456
444
  transform_kwargs = dict(
@@ -517,16 +505,40 @@ class OutputCodeClassifier(BaseTransformer):
517
505
  self._is_fitted = True
518
506
  return output_result
519
507
 
508
+
509
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
510
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
511
+ """ Method not supported for this class.
520
512
 
521
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
522
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
523
- """
513
+
514
+ Raises:
515
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
516
+
517
+ Args:
518
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
519
+ Snowpark or Pandas DataFrame.
520
+ output_cols_prefix: Prefix for the response columns
524
521
  Returns:
525
522
  Transformed dataset.
526
523
  """
527
- self.fit(dataset)
528
- assert self._sklearn_object is not None
529
- return self._sklearn_object.embedding_
524
+ self._infer_input_output_cols(dataset)
525
+ super()._check_dataset_type(dataset)
526
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
527
+ estimator=self._sklearn_object,
528
+ dataset=dataset,
529
+ input_cols=self.input_cols,
530
+ label_cols=self.label_cols,
531
+ sample_weight_col=self.sample_weight_col,
532
+ autogenerated=self._autogenerated,
533
+ subproject=_SUBPROJECT,
534
+ )
535
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
536
+ drop_input_cols=self._drop_input_cols,
537
+ expected_output_cols_list=self.output_cols,
538
+ )
539
+ self._sklearn_object = fitted_estimator
540
+ self._is_fitted = True
541
+ return output_result
530
542
 
531
543
 
532
544
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -617,10 +629,8 @@ class OutputCodeClassifier(BaseTransformer):
617
629
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
618
630
 
619
631
  if isinstance(dataset, DataFrame):
620
- self._deps = self._batch_inference_validate_snowpark(
621
- dataset=dataset,
622
- inference_method=inference_method,
623
- )
632
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
633
+ self._deps = self._get_dependencies()
624
634
  assert isinstance(
625
635
  dataset._session, Session
626
636
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -685,10 +695,8 @@ class OutputCodeClassifier(BaseTransformer):
685
695
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
686
696
 
687
697
  if isinstance(dataset, DataFrame):
688
- self._deps = self._batch_inference_validate_snowpark(
689
- dataset=dataset,
690
- inference_method=inference_method,
691
- )
698
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
699
+ self._deps = self._get_dependencies()
692
700
  assert isinstance(
693
701
  dataset._session, Session
694
702
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -750,10 +758,8 @@ class OutputCodeClassifier(BaseTransformer):
750
758
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
751
759
 
752
760
  if isinstance(dataset, DataFrame):
753
- self._deps = self._batch_inference_validate_snowpark(
754
- dataset=dataset,
755
- inference_method=inference_method,
756
- )
761
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
762
+ self._deps = self._get_dependencies()
757
763
  assert isinstance(
758
764
  dataset._session, Session
759
765
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -819,10 +825,8 @@ class OutputCodeClassifier(BaseTransformer):
819
825
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
820
826
 
821
827
  if isinstance(dataset, DataFrame):
822
- self._deps = self._batch_inference_validate_snowpark(
823
- dataset=dataset,
824
- inference_method=inference_method,
825
- )
828
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
829
+ self._deps = self._get_dependencies()
826
830
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
827
831
  transform_kwargs = dict(
828
832
  session=dataset._session,
@@ -886,17 +890,15 @@ class OutputCodeClassifier(BaseTransformer):
886
890
  transform_kwargs: ScoreKwargsTypedDict = dict()
887
891
 
888
892
  if isinstance(dataset, DataFrame):
889
- self._deps = self._batch_inference_validate_snowpark(
890
- dataset=dataset,
891
- inference_method="score",
892
- )
893
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
894
+ self._deps = self._get_dependencies()
893
895
  selected_cols = self._get_active_columns()
894
896
  if len(selected_cols) > 0:
895
897
  dataset = dataset.select(selected_cols)
896
898
  assert isinstance(dataset._session, Session) # keep mypy happy
897
899
  transform_kwargs = dict(
898
900
  session=dataset._session,
899
- dependencies=["snowflake-snowpark-python"] + self._deps,
901
+ dependencies=self._deps,
900
902
  score_sproc_imports=['sklearn'],
901
903
  )
902
904
  elif isinstance(dataset, pd.DataFrame):
@@ -961,11 +963,8 @@ class OutputCodeClassifier(BaseTransformer):
961
963
 
962
964
  if isinstance(dataset, DataFrame):
963
965
 
964
- self._deps = self._batch_inference_validate_snowpark(
965
- dataset=dataset,
966
- inference_method=inference_method,
967
-
968
- )
966
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
967
+ self._deps = self._get_dependencies()
969
968
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
970
969
  transform_kwargs = dict(
971
970
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.naive_bayes".replace("sk
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class BernoulliNB(BaseTransformer):
70
64
  r"""Naive Bayes classifier for multivariate Bernoulli models
71
65
  For more details on this class, see [sklearn.naive_bayes.BernoulliNB]
@@ -283,20 +277,17 @@ class BernoulliNB(BaseTransformer):
283
277
  self,
284
278
  dataset: DataFrame,
285
279
  inference_method: str,
286
- ) -> List[str]:
287
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
288
- return the available package that exists in the snowflake anaconda channel
280
+ ) -> None:
281
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
289
282
 
290
283
  Args:
291
284
  dataset: snowpark dataframe
292
285
  inference_method: the inference method such as predict, score...
293
-
286
+
294
287
  Raises:
295
288
  SnowflakeMLException: If the estimator is not fitted, raise error
296
289
  SnowflakeMLException: If the session is None, raise error
297
290
 
298
- Returns:
299
- A list of available package that exists in the snowflake anaconda channel
300
291
  """
301
292
  if not self._is_fitted:
302
293
  raise exceptions.SnowflakeMLException(
@@ -314,9 +305,7 @@ class BernoulliNB(BaseTransformer):
314
305
  "Session must not specified for snowpark dataset."
315
306
  ),
316
307
  )
317
- # Validate that key package version in user workspace are supported in snowflake conda channel
318
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
319
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
308
+
320
309
 
321
310
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
322
311
  @telemetry.send_api_usage_telemetry(
@@ -364,7 +353,8 @@ class BernoulliNB(BaseTransformer):
364
353
 
365
354
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
366
355
 
367
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
356
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
357
+ self._deps = self._get_dependencies()
368
358
  assert isinstance(
369
359
  dataset._session, Session
370
360
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -447,10 +437,8 @@ class BernoulliNB(BaseTransformer):
447
437
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
448
438
  expected_dtype = convert_sp_to_sf_type(output_types[0])
449
439
 
450
- self._deps = self._batch_inference_validate_snowpark(
451
- dataset=dataset,
452
- inference_method=inference_method,
453
- )
440
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
441
+ self._deps = self._get_dependencies()
454
442
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
455
443
 
456
444
  transform_kwargs = dict(
@@ -517,16 +505,40 @@ class BernoulliNB(BaseTransformer):
517
505
  self._is_fitted = True
518
506
  return output_result
519
507
 
508
+
509
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
510
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
511
+ """ Method not supported for this class.
520
512
 
521
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
522
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
523
- """
513
+
514
+ Raises:
515
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
516
+
517
+ Args:
518
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
519
+ Snowpark or Pandas DataFrame.
520
+ output_cols_prefix: Prefix for the response columns
524
521
  Returns:
525
522
  Transformed dataset.
526
523
  """
527
- self.fit(dataset)
528
- assert self._sklearn_object is not None
529
- return self._sklearn_object.embedding_
524
+ self._infer_input_output_cols(dataset)
525
+ super()._check_dataset_type(dataset)
526
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
527
+ estimator=self._sklearn_object,
528
+ dataset=dataset,
529
+ input_cols=self.input_cols,
530
+ label_cols=self.label_cols,
531
+ sample_weight_col=self.sample_weight_col,
532
+ autogenerated=self._autogenerated,
533
+ subproject=_SUBPROJECT,
534
+ )
535
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
536
+ drop_input_cols=self._drop_input_cols,
537
+ expected_output_cols_list=self.output_cols,
538
+ )
539
+ self._sklearn_object = fitted_estimator
540
+ self._is_fitted = True
541
+ return output_result
530
542
 
531
543
 
532
544
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -619,10 +631,8 @@ class BernoulliNB(BaseTransformer):
619
631
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
620
632
 
621
633
  if isinstance(dataset, DataFrame):
622
- self._deps = self._batch_inference_validate_snowpark(
623
- dataset=dataset,
624
- inference_method=inference_method,
625
- )
634
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
635
+ self._deps = self._get_dependencies()
626
636
  assert isinstance(
627
637
  dataset._session, Session
628
638
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -689,10 +699,8 @@ class BernoulliNB(BaseTransformer):
689
699
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
690
700
 
691
701
  if isinstance(dataset, DataFrame):
692
- self._deps = self._batch_inference_validate_snowpark(
693
- dataset=dataset,
694
- inference_method=inference_method,
695
- )
702
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
703
+ self._deps = self._get_dependencies()
696
704
  assert isinstance(
697
705
  dataset._session, Session
698
706
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -754,10 +762,8 @@ class BernoulliNB(BaseTransformer):
754
762
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
755
763
 
756
764
  if isinstance(dataset, DataFrame):
757
- self._deps = self._batch_inference_validate_snowpark(
758
- dataset=dataset,
759
- inference_method=inference_method,
760
- )
765
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
766
+ self._deps = self._get_dependencies()
761
767
  assert isinstance(
762
768
  dataset._session, Session
763
769
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -823,10 +829,8 @@ class BernoulliNB(BaseTransformer):
823
829
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
824
830
 
825
831
  if isinstance(dataset, DataFrame):
826
- self._deps = self._batch_inference_validate_snowpark(
827
- dataset=dataset,
828
- inference_method=inference_method,
829
- )
832
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
833
+ self._deps = self._get_dependencies()
830
834
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
831
835
  transform_kwargs = dict(
832
836
  session=dataset._session,
@@ -890,17 +894,15 @@ class BernoulliNB(BaseTransformer):
890
894
  transform_kwargs: ScoreKwargsTypedDict = dict()
891
895
 
892
896
  if isinstance(dataset, DataFrame):
893
- self._deps = self._batch_inference_validate_snowpark(
894
- dataset=dataset,
895
- inference_method="score",
896
- )
897
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
898
+ self._deps = self._get_dependencies()
897
899
  selected_cols = self._get_active_columns()
898
900
  if len(selected_cols) > 0:
899
901
  dataset = dataset.select(selected_cols)
900
902
  assert isinstance(dataset._session, Session) # keep mypy happy
901
903
  transform_kwargs = dict(
902
904
  session=dataset._session,
903
- dependencies=["snowflake-snowpark-python"] + self._deps,
905
+ dependencies=self._deps,
904
906
  score_sproc_imports=['sklearn'],
905
907
  )
906
908
  elif isinstance(dataset, pd.DataFrame):
@@ -965,11 +967,8 @@ class BernoulliNB(BaseTransformer):
965
967
 
966
968
  if isinstance(dataset, DataFrame):
967
969
 
968
- self._deps = self._batch_inference_validate_snowpark(
969
- dataset=dataset,
970
- inference_method=inference_method,
971
-
972
- )
970
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
971
+ self._deps = self._get_dependencies()
973
972
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
974
973
  transform_kwargs = dict(
975
974
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.naive_bayes".replace("sk
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class CategoricalNB(BaseTransformer):
70
64
  r"""Naive Bayes classifier for categorical features
71
65
  For more details on this class, see [sklearn.naive_bayes.CategoricalNB]
@@ -289,20 +283,17 @@ class CategoricalNB(BaseTransformer):
289
283
  self,
290
284
  dataset: DataFrame,
291
285
  inference_method: str,
292
- ) -> List[str]:
293
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
294
- return the available package that exists in the snowflake anaconda channel
286
+ ) -> None:
287
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
295
288
 
296
289
  Args:
297
290
  dataset: snowpark dataframe
298
291
  inference_method: the inference method such as predict, score...
299
-
292
+
300
293
  Raises:
301
294
  SnowflakeMLException: If the estimator is not fitted, raise error
302
295
  SnowflakeMLException: If the session is None, raise error
303
296
 
304
- Returns:
305
- A list of available package that exists in the snowflake anaconda channel
306
297
  """
307
298
  if not self._is_fitted:
308
299
  raise exceptions.SnowflakeMLException(
@@ -320,9 +311,7 @@ class CategoricalNB(BaseTransformer):
320
311
  "Session must not specified for snowpark dataset."
321
312
  ),
322
313
  )
323
- # Validate that key package version in user workspace are supported in snowflake conda channel
324
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
325
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
314
+
326
315
 
327
316
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
328
317
  @telemetry.send_api_usage_telemetry(
@@ -370,7 +359,8 @@ class CategoricalNB(BaseTransformer):
370
359
 
371
360
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
372
361
 
373
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
362
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
363
+ self._deps = self._get_dependencies()
374
364
  assert isinstance(
375
365
  dataset._session, Session
376
366
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -453,10 +443,8 @@ class CategoricalNB(BaseTransformer):
453
443
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
454
444
  expected_dtype = convert_sp_to_sf_type(output_types[0])
455
445
 
456
- self._deps = self._batch_inference_validate_snowpark(
457
- dataset=dataset,
458
- inference_method=inference_method,
459
- )
446
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
447
+ self._deps = self._get_dependencies()
460
448
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
461
449
 
462
450
  transform_kwargs = dict(
@@ -523,16 +511,40 @@ class CategoricalNB(BaseTransformer):
523
511
  self._is_fitted = True
524
512
  return output_result
525
513
 
514
+
515
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
516
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
517
+ """ Method not supported for this class.
526
518
 
527
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
528
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
529
- """
519
+
520
+ Raises:
521
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
522
+
523
+ Args:
524
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
525
+ Snowpark or Pandas DataFrame.
526
+ output_cols_prefix: Prefix for the response columns
530
527
  Returns:
531
528
  Transformed dataset.
532
529
  """
533
- self.fit(dataset)
534
- assert self._sklearn_object is not None
535
- return self._sklearn_object.embedding_
530
+ self._infer_input_output_cols(dataset)
531
+ super()._check_dataset_type(dataset)
532
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
533
+ estimator=self._sklearn_object,
534
+ dataset=dataset,
535
+ input_cols=self.input_cols,
536
+ label_cols=self.label_cols,
537
+ sample_weight_col=self.sample_weight_col,
538
+ autogenerated=self._autogenerated,
539
+ subproject=_SUBPROJECT,
540
+ )
541
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
542
+ drop_input_cols=self._drop_input_cols,
543
+ expected_output_cols_list=self.output_cols,
544
+ )
545
+ self._sklearn_object = fitted_estimator
546
+ self._is_fitted = True
547
+ return output_result
536
548
 
537
549
 
538
550
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -625,10 +637,8 @@ class CategoricalNB(BaseTransformer):
625
637
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
626
638
 
627
639
  if isinstance(dataset, DataFrame):
628
- self._deps = self._batch_inference_validate_snowpark(
629
- dataset=dataset,
630
- inference_method=inference_method,
631
- )
640
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
641
+ self._deps = self._get_dependencies()
632
642
  assert isinstance(
633
643
  dataset._session, Session
634
644
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -695,10 +705,8 @@ class CategoricalNB(BaseTransformer):
695
705
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
696
706
 
697
707
  if isinstance(dataset, DataFrame):
698
- self._deps = self._batch_inference_validate_snowpark(
699
- dataset=dataset,
700
- inference_method=inference_method,
701
- )
708
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
709
+ self._deps = self._get_dependencies()
702
710
  assert isinstance(
703
711
  dataset._session, Session
704
712
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -760,10 +768,8 @@ class CategoricalNB(BaseTransformer):
760
768
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
761
769
 
762
770
  if isinstance(dataset, DataFrame):
763
- self._deps = self._batch_inference_validate_snowpark(
764
- dataset=dataset,
765
- inference_method=inference_method,
766
- )
771
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
772
+ self._deps = self._get_dependencies()
767
773
  assert isinstance(
768
774
  dataset._session, Session
769
775
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -829,10 +835,8 @@ class CategoricalNB(BaseTransformer):
829
835
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
830
836
 
831
837
  if isinstance(dataset, DataFrame):
832
- self._deps = self._batch_inference_validate_snowpark(
833
- dataset=dataset,
834
- inference_method=inference_method,
835
- )
838
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
839
+ self._deps = self._get_dependencies()
836
840
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
837
841
  transform_kwargs = dict(
838
842
  session=dataset._session,
@@ -896,17 +900,15 @@ class CategoricalNB(BaseTransformer):
896
900
  transform_kwargs: ScoreKwargsTypedDict = dict()
897
901
 
898
902
  if isinstance(dataset, DataFrame):
899
- self._deps = self._batch_inference_validate_snowpark(
900
- dataset=dataset,
901
- inference_method="score",
902
- )
903
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
904
+ self._deps = self._get_dependencies()
903
905
  selected_cols = self._get_active_columns()
904
906
  if len(selected_cols) > 0:
905
907
  dataset = dataset.select(selected_cols)
906
908
  assert isinstance(dataset._session, Session) # keep mypy happy
907
909
  transform_kwargs = dict(
908
910
  session=dataset._session,
909
- dependencies=["snowflake-snowpark-python"] + self._deps,
911
+ dependencies=self._deps,
910
912
  score_sproc_imports=['sklearn'],
911
913
  )
912
914
  elif isinstance(dataset, pd.DataFrame):
@@ -971,11 +973,8 @@ class CategoricalNB(BaseTransformer):
971
973
 
972
974
  if isinstance(dataset, DataFrame):
973
975
 
974
- self._deps = self._batch_inference_validate_snowpark(
975
- dataset=dataset,
976
- inference_method=inference_method,
977
-
978
- )
976
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
977
+ self._deps = self._get_dependencies()
979
978
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
980
979
  transform_kwargs = dict(
981
980
  session = dataset._session,