snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class StackingRegressor(BaseTransformer):
70
64
  r"""Stack of estimators with a final regressor
71
65
  For more details on this class, see [sklearn.ensemble.StackingRegressor]
@@ -316,20 +310,17 @@ class StackingRegressor(BaseTransformer):
316
310
  self,
317
311
  dataset: DataFrame,
318
312
  inference_method: str,
319
- ) -> List[str]:
320
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
321
- return the available package that exists in the snowflake anaconda channel
313
+ ) -> None:
314
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
322
315
 
323
316
  Args:
324
317
  dataset: snowpark dataframe
325
318
  inference_method: the inference method such as predict, score...
326
-
319
+
327
320
  Raises:
328
321
  SnowflakeMLException: If the estimator is not fitted, raise error
329
322
  SnowflakeMLException: If the session is None, raise error
330
323
 
331
- Returns:
332
- A list of available package that exists in the snowflake anaconda channel
333
324
  """
334
325
  if not self._is_fitted:
335
326
  raise exceptions.SnowflakeMLException(
@@ -347,9 +338,7 @@ class StackingRegressor(BaseTransformer):
347
338
  "Session must not specified for snowpark dataset."
348
339
  ),
349
340
  )
350
- # Validate that key package version in user workspace are supported in snowflake conda channel
351
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
352
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
341
+
353
342
 
354
343
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
355
344
  @telemetry.send_api_usage_telemetry(
@@ -397,7 +386,8 @@ class StackingRegressor(BaseTransformer):
397
386
 
398
387
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
399
388
 
400
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
389
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
390
+ self._deps = self._get_dependencies()
401
391
  assert isinstance(
402
392
  dataset._session, Session
403
393
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -482,10 +472,8 @@ class StackingRegressor(BaseTransformer):
482
472
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
483
473
  expected_dtype = convert_sp_to_sf_type(output_types[0])
484
474
 
485
- self._deps = self._batch_inference_validate_snowpark(
486
- dataset=dataset,
487
- inference_method=inference_method,
488
- )
475
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
476
+ self._deps = self._get_dependencies()
489
477
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
490
478
 
491
479
  transform_kwargs = dict(
@@ -552,16 +540,42 @@ class StackingRegressor(BaseTransformer):
552
540
  self._is_fitted = True
553
541
  return output_result
554
542
 
543
+
544
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
545
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
546
+ """ Fit the estimators and return the predictions for X for each estimator
547
+ For more details on this function, see [sklearn.ensemble.StackingRegressor.fit_transform]
548
+ (https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.StackingRegressor.html#sklearn.ensemble.StackingRegressor.fit_transform)
549
+
555
550
 
556
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
557
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
558
- """
551
+ Raises:
552
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
553
+
554
+ Args:
555
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
556
+ Snowpark or Pandas DataFrame.
557
+ output_cols_prefix: Prefix for the response columns
559
558
  Returns:
560
559
  Transformed dataset.
561
560
  """
562
- self.fit(dataset)
563
- assert self._sklearn_object is not None
564
- return self._sklearn_object.embedding_
561
+ self._infer_input_output_cols(dataset)
562
+ super()._check_dataset_type(dataset)
563
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
564
+ estimator=self._sklearn_object,
565
+ dataset=dataset,
566
+ input_cols=self.input_cols,
567
+ label_cols=self.label_cols,
568
+ sample_weight_col=self.sample_weight_col,
569
+ autogenerated=self._autogenerated,
570
+ subproject=_SUBPROJECT,
571
+ )
572
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
573
+ drop_input_cols=self._drop_input_cols,
574
+ expected_output_cols_list=self.output_cols,
575
+ )
576
+ self._sklearn_object = fitted_estimator
577
+ self._is_fitted = True
578
+ return output_result
565
579
 
566
580
 
567
581
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -652,10 +666,8 @@ class StackingRegressor(BaseTransformer):
652
666
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
653
667
 
654
668
  if isinstance(dataset, DataFrame):
655
- self._deps = self._batch_inference_validate_snowpark(
656
- dataset=dataset,
657
- inference_method=inference_method,
658
- )
669
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
670
+ self._deps = self._get_dependencies()
659
671
  assert isinstance(
660
672
  dataset._session, Session
661
673
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -720,10 +732,8 @@ class StackingRegressor(BaseTransformer):
720
732
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
721
733
 
722
734
  if isinstance(dataset, DataFrame):
723
- self._deps = self._batch_inference_validate_snowpark(
724
- dataset=dataset,
725
- inference_method=inference_method,
726
- )
735
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
736
+ self._deps = self._get_dependencies()
727
737
  assert isinstance(
728
738
  dataset._session, Session
729
739
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -785,10 +795,8 @@ class StackingRegressor(BaseTransformer):
785
795
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
786
796
 
787
797
  if isinstance(dataset, DataFrame):
788
- self._deps = self._batch_inference_validate_snowpark(
789
- dataset=dataset,
790
- inference_method=inference_method,
791
- )
798
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
799
+ self._deps = self._get_dependencies()
792
800
  assert isinstance(
793
801
  dataset._session, Session
794
802
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -854,10 +862,8 @@ class StackingRegressor(BaseTransformer):
854
862
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
855
863
 
856
864
  if isinstance(dataset, DataFrame):
857
- self._deps = self._batch_inference_validate_snowpark(
858
- dataset=dataset,
859
- inference_method=inference_method,
860
- )
865
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
866
+ self._deps = self._get_dependencies()
861
867
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
862
868
  transform_kwargs = dict(
863
869
  session=dataset._session,
@@ -921,17 +927,15 @@ class StackingRegressor(BaseTransformer):
921
927
  transform_kwargs: ScoreKwargsTypedDict = dict()
922
928
 
923
929
  if isinstance(dataset, DataFrame):
924
- self._deps = self._batch_inference_validate_snowpark(
925
- dataset=dataset,
926
- inference_method="score",
927
- )
930
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
931
+ self._deps = self._get_dependencies()
928
932
  selected_cols = self._get_active_columns()
929
933
  if len(selected_cols) > 0:
930
934
  dataset = dataset.select(selected_cols)
931
935
  assert isinstance(dataset._session, Session) # keep mypy happy
932
936
  transform_kwargs = dict(
933
937
  session=dataset._session,
934
- dependencies=["snowflake-snowpark-python"] + self._deps,
938
+ dependencies=self._deps,
935
939
  score_sproc_imports=['sklearn'],
936
940
  )
937
941
  elif isinstance(dataset, pd.DataFrame):
@@ -996,11 +1000,8 @@ class StackingRegressor(BaseTransformer):
996
1000
 
997
1001
  if isinstance(dataset, DataFrame):
998
1002
 
999
- self._deps = self._batch_inference_validate_snowpark(
1000
- dataset=dataset,
1001
- inference_method=inference_method,
1002
-
1003
- )
1003
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1004
+ self._deps = self._get_dependencies()
1004
1005
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1005
1006
  transform_kwargs = dict(
1006
1007
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class VotingClassifier(BaseTransformer):
70
64
  r"""Soft Voting/Majority Rule classifier for unfitted estimators
71
65
  For more details on this class, see [sklearn.ensemble.VotingClassifier]
@@ -298,20 +292,17 @@ class VotingClassifier(BaseTransformer):
298
292
  self,
299
293
  dataset: DataFrame,
300
294
  inference_method: str,
301
- ) -> List[str]:
302
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
303
- return the available package that exists in the snowflake anaconda channel
295
+ ) -> None:
296
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
304
297
 
305
298
  Args:
306
299
  dataset: snowpark dataframe
307
300
  inference_method: the inference method such as predict, score...
308
-
301
+
309
302
  Raises:
310
303
  SnowflakeMLException: If the estimator is not fitted, raise error
311
304
  SnowflakeMLException: If the session is None, raise error
312
305
 
313
- Returns:
314
- A list of available package that exists in the snowflake anaconda channel
315
306
  """
316
307
  if not self._is_fitted:
317
308
  raise exceptions.SnowflakeMLException(
@@ -329,9 +320,7 @@ class VotingClassifier(BaseTransformer):
329
320
  "Session must not specified for snowpark dataset."
330
321
  ),
331
322
  )
332
- # Validate that key package version in user workspace are supported in snowflake conda channel
333
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
334
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
323
+
335
324
 
336
325
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
337
326
  @telemetry.send_api_usage_telemetry(
@@ -379,7 +368,8 @@ class VotingClassifier(BaseTransformer):
379
368
 
380
369
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
381
370
 
382
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
371
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
372
+ self._deps = self._get_dependencies()
383
373
  assert isinstance(
384
374
  dataset._session, Session
385
375
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -464,10 +454,8 @@ class VotingClassifier(BaseTransformer):
464
454
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
465
455
  expected_dtype = convert_sp_to_sf_type(output_types[0])
466
456
 
467
- self._deps = self._batch_inference_validate_snowpark(
468
- dataset=dataset,
469
- inference_method=inference_method,
470
- )
457
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
458
+ self._deps = self._get_dependencies()
471
459
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
472
460
 
473
461
  transform_kwargs = dict(
@@ -534,16 +522,42 @@ class VotingClassifier(BaseTransformer):
534
522
  self._is_fitted = True
535
523
  return output_result
536
524
 
525
+
526
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
527
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
528
+ """ Return class labels or probabilities for each estimator
529
+ For more details on this function, see [sklearn.ensemble.VotingClassifier.fit_transform]
530
+ (https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.VotingClassifier.html#sklearn.ensemble.VotingClassifier.fit_transform)
531
+
537
532
 
538
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
539
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
540
- """
533
+ Raises:
534
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
535
+
536
+ Args:
537
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
538
+ Snowpark or Pandas DataFrame.
539
+ output_cols_prefix: Prefix for the response columns
541
540
  Returns:
542
541
  Transformed dataset.
543
542
  """
544
- self.fit(dataset)
545
- assert self._sklearn_object is not None
546
- return self._sklearn_object.embedding_
543
+ self._infer_input_output_cols(dataset)
544
+ super()._check_dataset_type(dataset)
545
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
546
+ estimator=self._sklearn_object,
547
+ dataset=dataset,
548
+ input_cols=self.input_cols,
549
+ label_cols=self.label_cols,
550
+ sample_weight_col=self.sample_weight_col,
551
+ autogenerated=self._autogenerated,
552
+ subproject=_SUBPROJECT,
553
+ )
554
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
555
+ drop_input_cols=self._drop_input_cols,
556
+ expected_output_cols_list=self.output_cols,
557
+ )
558
+ self._sklearn_object = fitted_estimator
559
+ self._is_fitted = True
560
+ return output_result
547
561
 
548
562
 
549
563
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -636,10 +650,8 @@ class VotingClassifier(BaseTransformer):
636
650
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
637
651
 
638
652
  if isinstance(dataset, DataFrame):
639
- self._deps = self._batch_inference_validate_snowpark(
640
- dataset=dataset,
641
- inference_method=inference_method,
642
- )
653
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
654
+ self._deps = self._get_dependencies()
643
655
  assert isinstance(
644
656
  dataset._session, Session
645
657
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -706,10 +718,8 @@ class VotingClassifier(BaseTransformer):
706
718
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
707
719
 
708
720
  if isinstance(dataset, DataFrame):
709
- self._deps = self._batch_inference_validate_snowpark(
710
- dataset=dataset,
711
- inference_method=inference_method,
712
- )
721
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
722
+ self._deps = self._get_dependencies()
713
723
  assert isinstance(
714
724
  dataset._session, Session
715
725
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -771,10 +781,8 @@ class VotingClassifier(BaseTransformer):
771
781
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
772
782
 
773
783
  if isinstance(dataset, DataFrame):
774
- self._deps = self._batch_inference_validate_snowpark(
775
- dataset=dataset,
776
- inference_method=inference_method,
777
- )
784
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
785
+ self._deps = self._get_dependencies()
778
786
  assert isinstance(
779
787
  dataset._session, Session
780
788
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -840,10 +848,8 @@ class VotingClassifier(BaseTransformer):
840
848
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
841
849
 
842
850
  if isinstance(dataset, DataFrame):
843
- self._deps = self._batch_inference_validate_snowpark(
844
- dataset=dataset,
845
- inference_method=inference_method,
846
- )
851
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
852
+ self._deps = self._get_dependencies()
847
853
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
848
854
  transform_kwargs = dict(
849
855
  session=dataset._session,
@@ -907,17 +913,15 @@ class VotingClassifier(BaseTransformer):
907
913
  transform_kwargs: ScoreKwargsTypedDict = dict()
908
914
 
909
915
  if isinstance(dataset, DataFrame):
910
- self._deps = self._batch_inference_validate_snowpark(
911
- dataset=dataset,
912
- inference_method="score",
913
- )
916
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
917
+ self._deps = self._get_dependencies()
914
918
  selected_cols = self._get_active_columns()
915
919
  if len(selected_cols) > 0:
916
920
  dataset = dataset.select(selected_cols)
917
921
  assert isinstance(dataset._session, Session) # keep mypy happy
918
922
  transform_kwargs = dict(
919
923
  session=dataset._session,
920
- dependencies=["snowflake-snowpark-python"] + self._deps,
924
+ dependencies=self._deps,
921
925
  score_sproc_imports=['sklearn'],
922
926
  )
923
927
  elif isinstance(dataset, pd.DataFrame):
@@ -982,11 +986,8 @@ class VotingClassifier(BaseTransformer):
982
986
 
983
987
  if isinstance(dataset, DataFrame):
984
988
 
985
- self._deps = self._batch_inference_validate_snowpark(
986
- dataset=dataset,
987
- inference_method=inference_method,
988
-
989
- )
989
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
990
+ self._deps = self._get_dependencies()
990
991
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
991
992
  transform_kwargs = dict(
992
993
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class VotingRegressor(BaseTransformer):
70
64
  r"""Prediction voting regressor for unfitted estimators
71
65
  For more details on this class, see [sklearn.ensemble.VotingRegressor]
@@ -280,20 +274,17 @@ class VotingRegressor(BaseTransformer):
280
274
  self,
281
275
  dataset: DataFrame,
282
276
  inference_method: str,
283
- ) -> List[str]:
284
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
285
- return the available package that exists in the snowflake anaconda channel
277
+ ) -> None:
278
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
286
279
 
287
280
  Args:
288
281
  dataset: snowpark dataframe
289
282
  inference_method: the inference method such as predict, score...
290
-
283
+
291
284
  Raises:
292
285
  SnowflakeMLException: If the estimator is not fitted, raise error
293
286
  SnowflakeMLException: If the session is None, raise error
294
287
 
295
- Returns:
296
- A list of available package that exists in the snowflake anaconda channel
297
288
  """
298
289
  if not self._is_fitted:
299
290
  raise exceptions.SnowflakeMLException(
@@ -311,9 +302,7 @@ class VotingRegressor(BaseTransformer):
311
302
  "Session must not specified for snowpark dataset."
312
303
  ),
313
304
  )
314
- # Validate that key package version in user workspace are supported in snowflake conda channel
315
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
316
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
305
+
317
306
 
318
307
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
319
308
  @telemetry.send_api_usage_telemetry(
@@ -361,7 +350,8 @@ class VotingRegressor(BaseTransformer):
361
350
 
362
351
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
363
352
 
364
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
353
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
354
+ self._deps = self._get_dependencies()
365
355
  assert isinstance(
366
356
  dataset._session, Session
367
357
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -446,10 +436,8 @@ class VotingRegressor(BaseTransformer):
446
436
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
447
437
  expected_dtype = convert_sp_to_sf_type(output_types[0])
448
438
 
449
- self._deps = self._batch_inference_validate_snowpark(
450
- dataset=dataset,
451
- inference_method=inference_method,
452
- )
439
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
440
+ self._deps = self._get_dependencies()
453
441
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
454
442
 
455
443
  transform_kwargs = dict(
@@ -516,16 +504,42 @@ class VotingRegressor(BaseTransformer):
516
504
  self._is_fitted = True
517
505
  return output_result
518
506
 
507
+
508
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
509
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
510
+ """ Return class labels or probabilities for each estimator
511
+ For more details on this function, see [sklearn.ensemble.VotingRegressor.fit_transform]
512
+ (https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.VotingRegressor.html#sklearn.ensemble.VotingRegressor.fit_transform)
513
+
519
514
 
520
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
521
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
522
- """
515
+ Raises:
516
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
517
+
518
+ Args:
519
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
520
+ Snowpark or Pandas DataFrame.
521
+ output_cols_prefix: Prefix for the response columns
523
522
  Returns:
524
523
  Transformed dataset.
525
524
  """
526
- self.fit(dataset)
527
- assert self._sklearn_object is not None
528
- return self._sklearn_object.embedding_
525
+ self._infer_input_output_cols(dataset)
526
+ super()._check_dataset_type(dataset)
527
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
528
+ estimator=self._sklearn_object,
529
+ dataset=dataset,
530
+ input_cols=self.input_cols,
531
+ label_cols=self.label_cols,
532
+ sample_weight_col=self.sample_weight_col,
533
+ autogenerated=self._autogenerated,
534
+ subproject=_SUBPROJECT,
535
+ )
536
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
537
+ drop_input_cols=self._drop_input_cols,
538
+ expected_output_cols_list=self.output_cols,
539
+ )
540
+ self._sklearn_object = fitted_estimator
541
+ self._is_fitted = True
542
+ return output_result
529
543
 
530
544
 
531
545
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -616,10 +630,8 @@ class VotingRegressor(BaseTransformer):
616
630
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
617
631
 
618
632
  if isinstance(dataset, DataFrame):
619
- self._deps = self._batch_inference_validate_snowpark(
620
- dataset=dataset,
621
- inference_method=inference_method,
622
- )
633
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
634
+ self._deps = self._get_dependencies()
623
635
  assert isinstance(
624
636
  dataset._session, Session
625
637
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -684,10 +696,8 @@ class VotingRegressor(BaseTransformer):
684
696
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
685
697
 
686
698
  if isinstance(dataset, DataFrame):
687
- self._deps = self._batch_inference_validate_snowpark(
688
- dataset=dataset,
689
- inference_method=inference_method,
690
- )
699
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
700
+ self._deps = self._get_dependencies()
691
701
  assert isinstance(
692
702
  dataset._session, Session
693
703
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -749,10 +759,8 @@ class VotingRegressor(BaseTransformer):
749
759
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
750
760
 
751
761
  if isinstance(dataset, DataFrame):
752
- self._deps = self._batch_inference_validate_snowpark(
753
- dataset=dataset,
754
- inference_method=inference_method,
755
- )
762
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
763
+ self._deps = self._get_dependencies()
756
764
  assert isinstance(
757
765
  dataset._session, Session
758
766
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -818,10 +826,8 @@ class VotingRegressor(BaseTransformer):
818
826
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
819
827
 
820
828
  if isinstance(dataset, DataFrame):
821
- self._deps = self._batch_inference_validate_snowpark(
822
- dataset=dataset,
823
- inference_method=inference_method,
824
- )
829
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
830
+ self._deps = self._get_dependencies()
825
831
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
826
832
  transform_kwargs = dict(
827
833
  session=dataset._session,
@@ -885,17 +891,15 @@ class VotingRegressor(BaseTransformer):
885
891
  transform_kwargs: ScoreKwargsTypedDict = dict()
886
892
 
887
893
  if isinstance(dataset, DataFrame):
888
- self._deps = self._batch_inference_validate_snowpark(
889
- dataset=dataset,
890
- inference_method="score",
891
- )
894
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
895
+ self._deps = self._get_dependencies()
892
896
  selected_cols = self._get_active_columns()
893
897
  if len(selected_cols) > 0:
894
898
  dataset = dataset.select(selected_cols)
895
899
  assert isinstance(dataset._session, Session) # keep mypy happy
896
900
  transform_kwargs = dict(
897
901
  session=dataset._session,
898
- dependencies=["snowflake-snowpark-python"] + self._deps,
902
+ dependencies=self._deps,
899
903
  score_sproc_imports=['sklearn'],
900
904
  )
901
905
  elif isinstance(dataset, pd.DataFrame):
@@ -960,11 +964,8 @@ class VotingRegressor(BaseTransformer):
960
964
 
961
965
  if isinstance(dataset, DataFrame):
962
966
 
963
- self._deps = self._batch_inference_validate_snowpark(
964
- dataset=dataset,
965
- inference_method=inference_method,
966
-
967
- )
967
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
968
+ self._deps = self._get_dependencies()
968
969
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
969
970
  transform_kwargs = dict(
970
971
  session = dataset._session,