snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206) hide show
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class ExtraTreesClassifier(BaseTransformer):
70
64
  r"""An extra-trees classifier
71
65
  For more details on this class, see [sklearn.ensemble.ExtraTreesClassifier]
@@ -440,20 +434,17 @@ class ExtraTreesClassifier(BaseTransformer):
440
434
  self,
441
435
  dataset: DataFrame,
442
436
  inference_method: str,
443
- ) -> List[str]:
444
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
445
- return the available package that exists in the snowflake anaconda channel
437
+ ) -> None:
438
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
446
439
 
447
440
  Args:
448
441
  dataset: snowpark dataframe
449
442
  inference_method: the inference method such as predict, score...
450
-
443
+
451
444
  Raises:
452
445
  SnowflakeMLException: If the estimator is not fitted, raise error
453
446
  SnowflakeMLException: If the session is None, raise error
454
447
 
455
- Returns:
456
- A list of available package that exists in the snowflake anaconda channel
457
448
  """
458
449
  if not self._is_fitted:
459
450
  raise exceptions.SnowflakeMLException(
@@ -471,9 +462,7 @@ class ExtraTreesClassifier(BaseTransformer):
471
462
  "Session must not specified for snowpark dataset."
472
463
  ),
473
464
  )
474
- # Validate that key package version in user workspace are supported in snowflake conda channel
475
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
476
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
465
+
477
466
 
478
467
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
479
468
  @telemetry.send_api_usage_telemetry(
@@ -521,7 +510,8 @@ class ExtraTreesClassifier(BaseTransformer):
521
510
 
522
511
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
523
512
 
524
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
513
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
514
+ self._deps = self._get_dependencies()
525
515
  assert isinstance(
526
516
  dataset._session, Session
527
517
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -604,10 +594,8 @@ class ExtraTreesClassifier(BaseTransformer):
604
594
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
605
595
  expected_dtype = convert_sp_to_sf_type(output_types[0])
606
596
 
607
- self._deps = self._batch_inference_validate_snowpark(
608
- dataset=dataset,
609
- inference_method=inference_method,
610
- )
597
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
598
+ self._deps = self._get_dependencies()
611
599
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
612
600
 
613
601
  transform_kwargs = dict(
@@ -674,16 +662,40 @@ class ExtraTreesClassifier(BaseTransformer):
674
662
  self._is_fitted = True
675
663
  return output_result
676
664
 
665
+
666
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
667
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
668
+ """ Method not supported for this class.
677
669
 
678
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
679
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
680
- """
670
+
671
+ Raises:
672
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
673
+
674
+ Args:
675
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
676
+ Snowpark or Pandas DataFrame.
677
+ output_cols_prefix: Prefix for the response columns
681
678
  Returns:
682
679
  Transformed dataset.
683
680
  """
684
- self.fit(dataset)
685
- assert self._sklearn_object is not None
686
- return self._sklearn_object.embedding_
681
+ self._infer_input_output_cols(dataset)
682
+ super()._check_dataset_type(dataset)
683
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
684
+ estimator=self._sklearn_object,
685
+ dataset=dataset,
686
+ input_cols=self.input_cols,
687
+ label_cols=self.label_cols,
688
+ sample_weight_col=self.sample_weight_col,
689
+ autogenerated=self._autogenerated,
690
+ subproject=_SUBPROJECT,
691
+ )
692
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
693
+ drop_input_cols=self._drop_input_cols,
694
+ expected_output_cols_list=self.output_cols,
695
+ )
696
+ self._sklearn_object = fitted_estimator
697
+ self._is_fitted = True
698
+ return output_result
687
699
 
688
700
 
689
701
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -776,10 +788,8 @@ class ExtraTreesClassifier(BaseTransformer):
776
788
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
777
789
 
778
790
  if isinstance(dataset, DataFrame):
779
- self._deps = self._batch_inference_validate_snowpark(
780
- dataset=dataset,
781
- inference_method=inference_method,
782
- )
791
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
792
+ self._deps = self._get_dependencies()
783
793
  assert isinstance(
784
794
  dataset._session, Session
785
795
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -846,10 +856,8 @@ class ExtraTreesClassifier(BaseTransformer):
846
856
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
847
857
 
848
858
  if isinstance(dataset, DataFrame):
849
- self._deps = self._batch_inference_validate_snowpark(
850
- dataset=dataset,
851
- inference_method=inference_method,
852
- )
859
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
860
+ self._deps = self._get_dependencies()
853
861
  assert isinstance(
854
862
  dataset._session, Session
855
863
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -911,10 +919,8 @@ class ExtraTreesClassifier(BaseTransformer):
911
919
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
912
920
 
913
921
  if isinstance(dataset, DataFrame):
914
- self._deps = self._batch_inference_validate_snowpark(
915
- dataset=dataset,
916
- inference_method=inference_method,
917
- )
922
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
923
+ self._deps = self._get_dependencies()
918
924
  assert isinstance(
919
925
  dataset._session, Session
920
926
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -980,10 +986,8 @@ class ExtraTreesClassifier(BaseTransformer):
980
986
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
981
987
 
982
988
  if isinstance(dataset, DataFrame):
983
- self._deps = self._batch_inference_validate_snowpark(
984
- dataset=dataset,
985
- inference_method=inference_method,
986
- )
989
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
990
+ self._deps = self._get_dependencies()
987
991
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
988
992
  transform_kwargs = dict(
989
993
  session=dataset._session,
@@ -1047,17 +1051,15 @@ class ExtraTreesClassifier(BaseTransformer):
1047
1051
  transform_kwargs: ScoreKwargsTypedDict = dict()
1048
1052
 
1049
1053
  if isinstance(dataset, DataFrame):
1050
- self._deps = self._batch_inference_validate_snowpark(
1051
- dataset=dataset,
1052
- inference_method="score",
1053
- )
1054
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
1055
+ self._deps = self._get_dependencies()
1054
1056
  selected_cols = self._get_active_columns()
1055
1057
  if len(selected_cols) > 0:
1056
1058
  dataset = dataset.select(selected_cols)
1057
1059
  assert isinstance(dataset._session, Session) # keep mypy happy
1058
1060
  transform_kwargs = dict(
1059
1061
  session=dataset._session,
1060
- dependencies=["snowflake-snowpark-python"] + self._deps,
1062
+ dependencies=self._deps,
1061
1063
  score_sproc_imports=['sklearn'],
1062
1064
  )
1063
1065
  elif isinstance(dataset, pd.DataFrame):
@@ -1122,11 +1124,8 @@ class ExtraTreesClassifier(BaseTransformer):
1122
1124
 
1123
1125
  if isinstance(dataset, DataFrame):
1124
1126
 
1125
- self._deps = self._batch_inference_validate_snowpark(
1126
- dataset=dataset,
1127
- inference_method=inference_method,
1128
-
1129
- )
1127
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1128
+ self._deps = self._get_dependencies()
1130
1129
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1131
1130
  transform_kwargs = dict(
1132
1131
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class ExtraTreesRegressor(BaseTransformer):
70
64
  r"""An extra-trees regressor
71
65
  For more details on this class, see [sklearn.ensemble.ExtraTreesRegressor]
@@ -419,20 +413,17 @@ class ExtraTreesRegressor(BaseTransformer):
419
413
  self,
420
414
  dataset: DataFrame,
421
415
  inference_method: str,
422
- ) -> List[str]:
423
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
424
- return the available package that exists in the snowflake anaconda channel
416
+ ) -> None:
417
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
425
418
 
426
419
  Args:
427
420
  dataset: snowpark dataframe
428
421
  inference_method: the inference method such as predict, score...
429
-
422
+
430
423
  Raises:
431
424
  SnowflakeMLException: If the estimator is not fitted, raise error
432
425
  SnowflakeMLException: If the session is None, raise error
433
426
 
434
- Returns:
435
- A list of available package that exists in the snowflake anaconda channel
436
427
  """
437
428
  if not self._is_fitted:
438
429
  raise exceptions.SnowflakeMLException(
@@ -450,9 +441,7 @@ class ExtraTreesRegressor(BaseTransformer):
450
441
  "Session must not specified for snowpark dataset."
451
442
  ),
452
443
  )
453
- # Validate that key package version in user workspace are supported in snowflake conda channel
454
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
455
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
444
+
456
445
 
457
446
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
458
447
  @telemetry.send_api_usage_telemetry(
@@ -500,7 +489,8 @@ class ExtraTreesRegressor(BaseTransformer):
500
489
 
501
490
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
502
491
 
503
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
492
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
493
+ self._deps = self._get_dependencies()
504
494
  assert isinstance(
505
495
  dataset._session, Session
506
496
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -583,10 +573,8 @@ class ExtraTreesRegressor(BaseTransformer):
583
573
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
584
574
  expected_dtype = convert_sp_to_sf_type(output_types[0])
585
575
 
586
- self._deps = self._batch_inference_validate_snowpark(
587
- dataset=dataset,
588
- inference_method=inference_method,
589
- )
576
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
577
+ self._deps = self._get_dependencies()
590
578
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
591
579
 
592
580
  transform_kwargs = dict(
@@ -653,16 +641,40 @@ class ExtraTreesRegressor(BaseTransformer):
653
641
  self._is_fitted = True
654
642
  return output_result
655
643
 
644
+
645
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
646
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
647
+ """ Method not supported for this class.
656
648
 
657
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
658
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
659
- """
649
+
650
+ Raises:
651
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
652
+
653
+ Args:
654
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
655
+ Snowpark or Pandas DataFrame.
656
+ output_cols_prefix: Prefix for the response columns
660
657
  Returns:
661
658
  Transformed dataset.
662
659
  """
663
- self.fit(dataset)
664
- assert self._sklearn_object is not None
665
- return self._sklearn_object.embedding_
660
+ self._infer_input_output_cols(dataset)
661
+ super()._check_dataset_type(dataset)
662
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
663
+ estimator=self._sklearn_object,
664
+ dataset=dataset,
665
+ input_cols=self.input_cols,
666
+ label_cols=self.label_cols,
667
+ sample_weight_col=self.sample_weight_col,
668
+ autogenerated=self._autogenerated,
669
+ subproject=_SUBPROJECT,
670
+ )
671
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
672
+ drop_input_cols=self._drop_input_cols,
673
+ expected_output_cols_list=self.output_cols,
674
+ )
675
+ self._sklearn_object = fitted_estimator
676
+ self._is_fitted = True
677
+ return output_result
666
678
 
667
679
 
668
680
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -753,10 +765,8 @@ class ExtraTreesRegressor(BaseTransformer):
753
765
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
754
766
 
755
767
  if isinstance(dataset, DataFrame):
756
- self._deps = self._batch_inference_validate_snowpark(
757
- dataset=dataset,
758
- inference_method=inference_method,
759
- )
768
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
769
+ self._deps = self._get_dependencies()
760
770
  assert isinstance(
761
771
  dataset._session, Session
762
772
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -821,10 +831,8 @@ class ExtraTreesRegressor(BaseTransformer):
821
831
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
822
832
 
823
833
  if isinstance(dataset, DataFrame):
824
- self._deps = self._batch_inference_validate_snowpark(
825
- dataset=dataset,
826
- inference_method=inference_method,
827
- )
834
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
835
+ self._deps = self._get_dependencies()
828
836
  assert isinstance(
829
837
  dataset._session, Session
830
838
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -886,10 +894,8 @@ class ExtraTreesRegressor(BaseTransformer):
886
894
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
887
895
 
888
896
  if isinstance(dataset, DataFrame):
889
- self._deps = self._batch_inference_validate_snowpark(
890
- dataset=dataset,
891
- inference_method=inference_method,
892
- )
897
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
898
+ self._deps = self._get_dependencies()
893
899
  assert isinstance(
894
900
  dataset._session, Session
895
901
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -955,10 +961,8 @@ class ExtraTreesRegressor(BaseTransformer):
955
961
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
956
962
 
957
963
  if isinstance(dataset, DataFrame):
958
- self._deps = self._batch_inference_validate_snowpark(
959
- dataset=dataset,
960
- inference_method=inference_method,
961
- )
964
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
965
+ self._deps = self._get_dependencies()
962
966
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
963
967
  transform_kwargs = dict(
964
968
  session=dataset._session,
@@ -1022,17 +1026,15 @@ class ExtraTreesRegressor(BaseTransformer):
1022
1026
  transform_kwargs: ScoreKwargsTypedDict = dict()
1023
1027
 
1024
1028
  if isinstance(dataset, DataFrame):
1025
- self._deps = self._batch_inference_validate_snowpark(
1026
- dataset=dataset,
1027
- inference_method="score",
1028
- )
1029
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
1030
+ self._deps = self._get_dependencies()
1029
1031
  selected_cols = self._get_active_columns()
1030
1032
  if len(selected_cols) > 0:
1031
1033
  dataset = dataset.select(selected_cols)
1032
1034
  assert isinstance(dataset._session, Session) # keep mypy happy
1033
1035
  transform_kwargs = dict(
1034
1036
  session=dataset._session,
1035
- dependencies=["snowflake-snowpark-python"] + self._deps,
1037
+ dependencies=self._deps,
1036
1038
  score_sproc_imports=['sklearn'],
1037
1039
  )
1038
1040
  elif isinstance(dataset, pd.DataFrame):
@@ -1097,11 +1099,8 @@ class ExtraTreesRegressor(BaseTransformer):
1097
1099
 
1098
1100
  if isinstance(dataset, DataFrame):
1099
1101
 
1100
- self._deps = self._batch_inference_validate_snowpark(
1101
- dataset=dataset,
1102
- inference_method=inference_method,
1103
-
1104
- )
1102
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1103
+ self._deps = self._get_dependencies()
1105
1104
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1106
1105
  transform_kwargs = dict(
1107
1106
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class GradientBoostingClassifier(BaseTransformer):
70
64
  r"""Gradient Boosting for classification
71
65
  For more details on this class, see [sklearn.ensemble.GradientBoostingClassifier]
@@ -452,20 +446,17 @@ class GradientBoostingClassifier(BaseTransformer):
452
446
  self,
453
447
  dataset: DataFrame,
454
448
  inference_method: str,
455
- ) -> List[str]:
456
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
457
- return the available package that exists in the snowflake anaconda channel
449
+ ) -> None:
450
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
458
451
 
459
452
  Args:
460
453
  dataset: snowpark dataframe
461
454
  inference_method: the inference method such as predict, score...
462
-
455
+
463
456
  Raises:
464
457
  SnowflakeMLException: If the estimator is not fitted, raise error
465
458
  SnowflakeMLException: If the session is None, raise error
466
459
 
467
- Returns:
468
- A list of available package that exists in the snowflake anaconda channel
469
460
  """
470
461
  if not self._is_fitted:
471
462
  raise exceptions.SnowflakeMLException(
@@ -483,9 +474,7 @@ class GradientBoostingClassifier(BaseTransformer):
483
474
  "Session must not specified for snowpark dataset."
484
475
  ),
485
476
  )
486
- # Validate that key package version in user workspace are supported in snowflake conda channel
487
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
488
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
477
+
489
478
 
490
479
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
491
480
  @telemetry.send_api_usage_telemetry(
@@ -533,7 +522,8 @@ class GradientBoostingClassifier(BaseTransformer):
533
522
 
534
523
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
535
524
 
536
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
525
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
526
+ self._deps = self._get_dependencies()
537
527
  assert isinstance(
538
528
  dataset._session, Session
539
529
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -616,10 +606,8 @@ class GradientBoostingClassifier(BaseTransformer):
616
606
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
617
607
  expected_dtype = convert_sp_to_sf_type(output_types[0])
618
608
 
619
- self._deps = self._batch_inference_validate_snowpark(
620
- dataset=dataset,
621
- inference_method=inference_method,
622
- )
609
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
610
+ self._deps = self._get_dependencies()
623
611
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
624
612
 
625
613
  transform_kwargs = dict(
@@ -686,16 +674,40 @@ class GradientBoostingClassifier(BaseTransformer):
686
674
  self._is_fitted = True
687
675
  return output_result
688
676
 
677
+
678
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
679
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
680
+ """ Method not supported for this class.
689
681
 
690
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
691
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
692
- """
682
+
683
+ Raises:
684
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
685
+
686
+ Args:
687
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
688
+ Snowpark or Pandas DataFrame.
689
+ output_cols_prefix: Prefix for the response columns
693
690
  Returns:
694
691
  Transformed dataset.
695
692
  """
696
- self.fit(dataset)
697
- assert self._sklearn_object is not None
698
- return self._sklearn_object.embedding_
693
+ self._infer_input_output_cols(dataset)
694
+ super()._check_dataset_type(dataset)
695
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
696
+ estimator=self._sklearn_object,
697
+ dataset=dataset,
698
+ input_cols=self.input_cols,
699
+ label_cols=self.label_cols,
700
+ sample_weight_col=self.sample_weight_col,
701
+ autogenerated=self._autogenerated,
702
+ subproject=_SUBPROJECT,
703
+ )
704
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
705
+ drop_input_cols=self._drop_input_cols,
706
+ expected_output_cols_list=self.output_cols,
707
+ )
708
+ self._sklearn_object = fitted_estimator
709
+ self._is_fitted = True
710
+ return output_result
699
711
 
700
712
 
701
713
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -788,10 +800,8 @@ class GradientBoostingClassifier(BaseTransformer):
788
800
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
789
801
 
790
802
  if isinstance(dataset, DataFrame):
791
- self._deps = self._batch_inference_validate_snowpark(
792
- dataset=dataset,
793
- inference_method=inference_method,
794
- )
803
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
804
+ self._deps = self._get_dependencies()
795
805
  assert isinstance(
796
806
  dataset._session, Session
797
807
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -858,10 +868,8 @@ class GradientBoostingClassifier(BaseTransformer):
858
868
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
859
869
 
860
870
  if isinstance(dataset, DataFrame):
861
- self._deps = self._batch_inference_validate_snowpark(
862
- dataset=dataset,
863
- inference_method=inference_method,
864
- )
871
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
872
+ self._deps = self._get_dependencies()
865
873
  assert isinstance(
866
874
  dataset._session, Session
867
875
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -925,10 +933,8 @@ class GradientBoostingClassifier(BaseTransformer):
925
933
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
926
934
 
927
935
  if isinstance(dataset, DataFrame):
928
- self._deps = self._batch_inference_validate_snowpark(
929
- dataset=dataset,
930
- inference_method=inference_method,
931
- )
936
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
937
+ self._deps = self._get_dependencies()
932
938
  assert isinstance(
933
939
  dataset._session, Session
934
940
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -994,10 +1000,8 @@ class GradientBoostingClassifier(BaseTransformer):
994
1000
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
995
1001
 
996
1002
  if isinstance(dataset, DataFrame):
997
- self._deps = self._batch_inference_validate_snowpark(
998
- dataset=dataset,
999
- inference_method=inference_method,
1000
- )
1003
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1004
+ self._deps = self._get_dependencies()
1001
1005
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1002
1006
  transform_kwargs = dict(
1003
1007
  session=dataset._session,
@@ -1061,17 +1065,15 @@ class GradientBoostingClassifier(BaseTransformer):
1061
1065
  transform_kwargs: ScoreKwargsTypedDict = dict()
1062
1066
 
1063
1067
  if isinstance(dataset, DataFrame):
1064
- self._deps = self._batch_inference_validate_snowpark(
1065
- dataset=dataset,
1066
- inference_method="score",
1067
- )
1068
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
1069
+ self._deps = self._get_dependencies()
1068
1070
  selected_cols = self._get_active_columns()
1069
1071
  if len(selected_cols) > 0:
1070
1072
  dataset = dataset.select(selected_cols)
1071
1073
  assert isinstance(dataset._session, Session) # keep mypy happy
1072
1074
  transform_kwargs = dict(
1073
1075
  session=dataset._session,
1074
- dependencies=["snowflake-snowpark-python"] + self._deps,
1076
+ dependencies=self._deps,
1075
1077
  score_sproc_imports=['sklearn'],
1076
1078
  )
1077
1079
  elif isinstance(dataset, pd.DataFrame):
@@ -1136,11 +1138,8 @@ class GradientBoostingClassifier(BaseTransformer):
1136
1138
 
1137
1139
  if isinstance(dataset, DataFrame):
1138
1140
 
1139
- self._deps = self._batch_inference_validate_snowpark(
1140
- dataset=dataset,
1141
- inference_method=inference_method,
1142
-
1143
- )
1141
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1142
+ self._deps = self._get_dependencies()
1144
1143
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1145
1144
  transform_kwargs = dict(
1146
1145
  session = dataset._session,