snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206) hide show
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class GradientBoostingRegressor(BaseTransformer):
70
64
  r"""Gradient Boosting for regression
71
65
  For more details on this class, see [sklearn.ensemble.GradientBoostingRegressor]
@@ -461,20 +455,17 @@ class GradientBoostingRegressor(BaseTransformer):
461
455
  self,
462
456
  dataset: DataFrame,
463
457
  inference_method: str,
464
- ) -> List[str]:
465
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
466
- return the available package that exists in the snowflake anaconda channel
458
+ ) -> None:
459
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
467
460
 
468
461
  Args:
469
462
  dataset: snowpark dataframe
470
463
  inference_method: the inference method such as predict, score...
471
-
464
+
472
465
  Raises:
473
466
  SnowflakeMLException: If the estimator is not fitted, raise error
474
467
  SnowflakeMLException: If the session is None, raise error
475
468
 
476
- Returns:
477
- A list of available package that exists in the snowflake anaconda channel
478
469
  """
479
470
  if not self._is_fitted:
480
471
  raise exceptions.SnowflakeMLException(
@@ -492,9 +483,7 @@ class GradientBoostingRegressor(BaseTransformer):
492
483
  "Session must not specified for snowpark dataset."
493
484
  ),
494
485
  )
495
- # Validate that key package version in user workspace are supported in snowflake conda channel
496
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
497
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
486
+
498
487
 
499
488
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
500
489
  @telemetry.send_api_usage_telemetry(
@@ -542,7 +531,8 @@ class GradientBoostingRegressor(BaseTransformer):
542
531
 
543
532
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
544
533
 
545
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
534
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
535
+ self._deps = self._get_dependencies()
546
536
  assert isinstance(
547
537
  dataset._session, Session
548
538
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -625,10 +615,8 @@ class GradientBoostingRegressor(BaseTransformer):
625
615
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
626
616
  expected_dtype = convert_sp_to_sf_type(output_types[0])
627
617
 
628
- self._deps = self._batch_inference_validate_snowpark(
629
- dataset=dataset,
630
- inference_method=inference_method,
631
- )
618
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
619
+ self._deps = self._get_dependencies()
632
620
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
633
621
 
634
622
  transform_kwargs = dict(
@@ -695,16 +683,40 @@ class GradientBoostingRegressor(BaseTransformer):
695
683
  self._is_fitted = True
696
684
  return output_result
697
685
 
686
+
687
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
688
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
689
+ """ Method not supported for this class.
698
690
 
699
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
700
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
701
- """
691
+
692
+ Raises:
693
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
694
+
695
+ Args:
696
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
697
+ Snowpark or Pandas DataFrame.
698
+ output_cols_prefix: Prefix for the response columns
702
699
  Returns:
703
700
  Transformed dataset.
704
701
  """
705
- self.fit(dataset)
706
- assert self._sklearn_object is not None
707
- return self._sklearn_object.embedding_
702
+ self._infer_input_output_cols(dataset)
703
+ super()._check_dataset_type(dataset)
704
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
705
+ estimator=self._sklearn_object,
706
+ dataset=dataset,
707
+ input_cols=self.input_cols,
708
+ label_cols=self.label_cols,
709
+ sample_weight_col=self.sample_weight_col,
710
+ autogenerated=self._autogenerated,
711
+ subproject=_SUBPROJECT,
712
+ )
713
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
714
+ drop_input_cols=self._drop_input_cols,
715
+ expected_output_cols_list=self.output_cols,
716
+ )
717
+ self._sklearn_object = fitted_estimator
718
+ self._is_fitted = True
719
+ return output_result
708
720
 
709
721
 
710
722
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -795,10 +807,8 @@ class GradientBoostingRegressor(BaseTransformer):
795
807
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
796
808
 
797
809
  if isinstance(dataset, DataFrame):
798
- self._deps = self._batch_inference_validate_snowpark(
799
- dataset=dataset,
800
- inference_method=inference_method,
801
- )
810
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
811
+ self._deps = self._get_dependencies()
802
812
  assert isinstance(
803
813
  dataset._session, Session
804
814
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -863,10 +873,8 @@ class GradientBoostingRegressor(BaseTransformer):
863
873
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
864
874
 
865
875
  if isinstance(dataset, DataFrame):
866
- self._deps = self._batch_inference_validate_snowpark(
867
- dataset=dataset,
868
- inference_method=inference_method,
869
- )
876
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
877
+ self._deps = self._get_dependencies()
870
878
  assert isinstance(
871
879
  dataset._session, Session
872
880
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -928,10 +936,8 @@ class GradientBoostingRegressor(BaseTransformer):
928
936
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
929
937
 
930
938
  if isinstance(dataset, DataFrame):
931
- self._deps = self._batch_inference_validate_snowpark(
932
- dataset=dataset,
933
- inference_method=inference_method,
934
- )
939
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
940
+ self._deps = self._get_dependencies()
935
941
  assert isinstance(
936
942
  dataset._session, Session
937
943
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -997,10 +1003,8 @@ class GradientBoostingRegressor(BaseTransformer):
997
1003
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
998
1004
 
999
1005
  if isinstance(dataset, DataFrame):
1000
- self._deps = self._batch_inference_validate_snowpark(
1001
- dataset=dataset,
1002
- inference_method=inference_method,
1003
- )
1006
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1007
+ self._deps = self._get_dependencies()
1004
1008
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1005
1009
  transform_kwargs = dict(
1006
1010
  session=dataset._session,
@@ -1064,17 +1068,15 @@ class GradientBoostingRegressor(BaseTransformer):
1064
1068
  transform_kwargs: ScoreKwargsTypedDict = dict()
1065
1069
 
1066
1070
  if isinstance(dataset, DataFrame):
1067
- self._deps = self._batch_inference_validate_snowpark(
1068
- dataset=dataset,
1069
- inference_method="score",
1070
- )
1071
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
1072
+ self._deps = self._get_dependencies()
1071
1073
  selected_cols = self._get_active_columns()
1072
1074
  if len(selected_cols) > 0:
1073
1075
  dataset = dataset.select(selected_cols)
1074
1076
  assert isinstance(dataset._session, Session) # keep mypy happy
1075
1077
  transform_kwargs = dict(
1076
1078
  session=dataset._session,
1077
- dependencies=["snowflake-snowpark-python"] + self._deps,
1079
+ dependencies=self._deps,
1078
1080
  score_sproc_imports=['sklearn'],
1079
1081
  )
1080
1082
  elif isinstance(dataset, pd.DataFrame):
@@ -1139,11 +1141,8 @@ class GradientBoostingRegressor(BaseTransformer):
1139
1141
 
1140
1142
  if isinstance(dataset, DataFrame):
1141
1143
 
1142
- self._deps = self._batch_inference_validate_snowpark(
1143
- dataset=dataset,
1144
- inference_method=inference_method,
1145
-
1146
- )
1144
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1145
+ self._deps = self._get_dependencies()
1147
1146
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1148
1147
  transform_kwargs = dict(
1149
1148
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class HistGradientBoostingClassifier(BaseTransformer):
70
64
  r"""Histogram-based Gradient Boosting Classification Tree
71
65
  For more details on this class, see [sklearn.ensemble.HistGradientBoostingClassifier]
@@ -433,20 +427,17 @@ class HistGradientBoostingClassifier(BaseTransformer):
433
427
  self,
434
428
  dataset: DataFrame,
435
429
  inference_method: str,
436
- ) -> List[str]:
437
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
438
- return the available package that exists in the snowflake anaconda channel
430
+ ) -> None:
431
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
439
432
 
440
433
  Args:
441
434
  dataset: snowpark dataframe
442
435
  inference_method: the inference method such as predict, score...
443
-
436
+
444
437
  Raises:
445
438
  SnowflakeMLException: If the estimator is not fitted, raise error
446
439
  SnowflakeMLException: If the session is None, raise error
447
440
 
448
- Returns:
449
- A list of available package that exists in the snowflake anaconda channel
450
441
  """
451
442
  if not self._is_fitted:
452
443
  raise exceptions.SnowflakeMLException(
@@ -464,9 +455,7 @@ class HistGradientBoostingClassifier(BaseTransformer):
464
455
  "Session must not specified for snowpark dataset."
465
456
  ),
466
457
  )
467
- # Validate that key package version in user workspace are supported in snowflake conda channel
468
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
469
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
458
+
470
459
 
471
460
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
472
461
  @telemetry.send_api_usage_telemetry(
@@ -514,7 +503,8 @@ class HistGradientBoostingClassifier(BaseTransformer):
514
503
 
515
504
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
516
505
 
517
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
506
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
507
+ self._deps = self._get_dependencies()
518
508
  assert isinstance(
519
509
  dataset._session, Session
520
510
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -597,10 +587,8 @@ class HistGradientBoostingClassifier(BaseTransformer):
597
587
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
598
588
  expected_dtype = convert_sp_to_sf_type(output_types[0])
599
589
 
600
- self._deps = self._batch_inference_validate_snowpark(
601
- dataset=dataset,
602
- inference_method=inference_method,
603
- )
590
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
591
+ self._deps = self._get_dependencies()
604
592
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
605
593
 
606
594
  transform_kwargs = dict(
@@ -667,16 +655,40 @@ class HistGradientBoostingClassifier(BaseTransformer):
667
655
  self._is_fitted = True
668
656
  return output_result
669
657
 
658
+
659
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
660
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
661
+ """ Method not supported for this class.
670
662
 
671
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
672
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
673
- """
663
+
664
+ Raises:
665
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
666
+
667
+ Args:
668
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
669
+ Snowpark or Pandas DataFrame.
670
+ output_cols_prefix: Prefix for the response columns
674
671
  Returns:
675
672
  Transformed dataset.
676
673
  """
677
- self.fit(dataset)
678
- assert self._sklearn_object is not None
679
- return self._sklearn_object.embedding_
674
+ self._infer_input_output_cols(dataset)
675
+ super()._check_dataset_type(dataset)
676
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
677
+ estimator=self._sklearn_object,
678
+ dataset=dataset,
679
+ input_cols=self.input_cols,
680
+ label_cols=self.label_cols,
681
+ sample_weight_col=self.sample_weight_col,
682
+ autogenerated=self._autogenerated,
683
+ subproject=_SUBPROJECT,
684
+ )
685
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
686
+ drop_input_cols=self._drop_input_cols,
687
+ expected_output_cols_list=self.output_cols,
688
+ )
689
+ self._sklearn_object = fitted_estimator
690
+ self._is_fitted = True
691
+ return output_result
680
692
 
681
693
 
682
694
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -769,10 +781,8 @@ class HistGradientBoostingClassifier(BaseTransformer):
769
781
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
770
782
 
771
783
  if isinstance(dataset, DataFrame):
772
- self._deps = self._batch_inference_validate_snowpark(
773
- dataset=dataset,
774
- inference_method=inference_method,
775
- )
784
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
785
+ self._deps = self._get_dependencies()
776
786
  assert isinstance(
777
787
  dataset._session, Session
778
788
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -839,10 +849,8 @@ class HistGradientBoostingClassifier(BaseTransformer):
839
849
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
840
850
 
841
851
  if isinstance(dataset, DataFrame):
842
- self._deps = self._batch_inference_validate_snowpark(
843
- dataset=dataset,
844
- inference_method=inference_method,
845
- )
852
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
853
+ self._deps = self._get_dependencies()
846
854
  assert isinstance(
847
855
  dataset._session, Session
848
856
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -906,10 +914,8 @@ class HistGradientBoostingClassifier(BaseTransformer):
906
914
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
907
915
 
908
916
  if isinstance(dataset, DataFrame):
909
- self._deps = self._batch_inference_validate_snowpark(
910
- dataset=dataset,
911
- inference_method=inference_method,
912
- )
917
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
918
+ self._deps = self._get_dependencies()
913
919
  assert isinstance(
914
920
  dataset._session, Session
915
921
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -975,10 +981,8 @@ class HistGradientBoostingClassifier(BaseTransformer):
975
981
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
976
982
 
977
983
  if isinstance(dataset, DataFrame):
978
- self._deps = self._batch_inference_validate_snowpark(
979
- dataset=dataset,
980
- inference_method=inference_method,
981
- )
984
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
985
+ self._deps = self._get_dependencies()
982
986
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
983
987
  transform_kwargs = dict(
984
988
  session=dataset._session,
@@ -1042,17 +1046,15 @@ class HistGradientBoostingClassifier(BaseTransformer):
1042
1046
  transform_kwargs: ScoreKwargsTypedDict = dict()
1043
1047
 
1044
1048
  if isinstance(dataset, DataFrame):
1045
- self._deps = self._batch_inference_validate_snowpark(
1046
- dataset=dataset,
1047
- inference_method="score",
1048
- )
1049
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
1050
+ self._deps = self._get_dependencies()
1049
1051
  selected_cols = self._get_active_columns()
1050
1052
  if len(selected_cols) > 0:
1051
1053
  dataset = dataset.select(selected_cols)
1052
1054
  assert isinstance(dataset._session, Session) # keep mypy happy
1053
1055
  transform_kwargs = dict(
1054
1056
  session=dataset._session,
1055
- dependencies=["snowflake-snowpark-python"] + self._deps,
1057
+ dependencies=self._deps,
1056
1058
  score_sproc_imports=['sklearn'],
1057
1059
  )
1058
1060
  elif isinstance(dataset, pd.DataFrame):
@@ -1117,11 +1119,8 @@ class HistGradientBoostingClassifier(BaseTransformer):
1117
1119
 
1118
1120
  if isinstance(dataset, DataFrame):
1119
1121
 
1120
- self._deps = self._batch_inference_validate_snowpark(
1121
- dataset=dataset,
1122
- inference_method=inference_method,
1123
-
1124
- )
1122
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1123
+ self._deps = self._get_dependencies()
1125
1124
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1126
1125
  transform_kwargs = dict(
1127
1126
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class HistGradientBoostingRegressor(BaseTransformer):
70
64
  r"""Histogram-based Gradient Boosting Regression Tree
71
65
  For more details on this class, see [sklearn.ensemble.HistGradientBoostingRegressor]
@@ -424,20 +418,17 @@ class HistGradientBoostingRegressor(BaseTransformer):
424
418
  self,
425
419
  dataset: DataFrame,
426
420
  inference_method: str,
427
- ) -> List[str]:
428
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
429
- return the available package that exists in the snowflake anaconda channel
421
+ ) -> None:
422
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
430
423
 
431
424
  Args:
432
425
  dataset: snowpark dataframe
433
426
  inference_method: the inference method such as predict, score...
434
-
427
+
435
428
  Raises:
436
429
  SnowflakeMLException: If the estimator is not fitted, raise error
437
430
  SnowflakeMLException: If the session is None, raise error
438
431
 
439
- Returns:
440
- A list of available package that exists in the snowflake anaconda channel
441
432
  """
442
433
  if not self._is_fitted:
443
434
  raise exceptions.SnowflakeMLException(
@@ -455,9 +446,7 @@ class HistGradientBoostingRegressor(BaseTransformer):
455
446
  "Session must not specified for snowpark dataset."
456
447
  ),
457
448
  )
458
- # Validate that key package version in user workspace are supported in snowflake conda channel
459
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
460
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
449
+
461
450
 
462
451
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
463
452
  @telemetry.send_api_usage_telemetry(
@@ -505,7 +494,8 @@ class HistGradientBoostingRegressor(BaseTransformer):
505
494
 
506
495
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
507
496
 
508
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
497
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
498
+ self._deps = self._get_dependencies()
509
499
  assert isinstance(
510
500
  dataset._session, Session
511
501
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -588,10 +578,8 @@ class HistGradientBoostingRegressor(BaseTransformer):
588
578
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
589
579
  expected_dtype = convert_sp_to_sf_type(output_types[0])
590
580
 
591
- self._deps = self._batch_inference_validate_snowpark(
592
- dataset=dataset,
593
- inference_method=inference_method,
594
- )
581
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
582
+ self._deps = self._get_dependencies()
595
583
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
596
584
 
597
585
  transform_kwargs = dict(
@@ -658,16 +646,40 @@ class HistGradientBoostingRegressor(BaseTransformer):
658
646
  self._is_fitted = True
659
647
  return output_result
660
648
 
649
+
650
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
651
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
652
+ """ Method not supported for this class.
661
653
 
662
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
663
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
664
- """
654
+
655
+ Raises:
656
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
657
+
658
+ Args:
659
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
660
+ Snowpark or Pandas DataFrame.
661
+ output_cols_prefix: Prefix for the response columns
665
662
  Returns:
666
663
  Transformed dataset.
667
664
  """
668
- self.fit(dataset)
669
- assert self._sklearn_object is not None
670
- return self._sklearn_object.embedding_
665
+ self._infer_input_output_cols(dataset)
666
+ super()._check_dataset_type(dataset)
667
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
668
+ estimator=self._sklearn_object,
669
+ dataset=dataset,
670
+ input_cols=self.input_cols,
671
+ label_cols=self.label_cols,
672
+ sample_weight_col=self.sample_weight_col,
673
+ autogenerated=self._autogenerated,
674
+ subproject=_SUBPROJECT,
675
+ )
676
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
677
+ drop_input_cols=self._drop_input_cols,
678
+ expected_output_cols_list=self.output_cols,
679
+ )
680
+ self._sklearn_object = fitted_estimator
681
+ self._is_fitted = True
682
+ return output_result
671
683
 
672
684
 
673
685
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -758,10 +770,8 @@ class HistGradientBoostingRegressor(BaseTransformer):
758
770
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
759
771
 
760
772
  if isinstance(dataset, DataFrame):
761
- self._deps = self._batch_inference_validate_snowpark(
762
- dataset=dataset,
763
- inference_method=inference_method,
764
- )
773
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
774
+ self._deps = self._get_dependencies()
765
775
  assert isinstance(
766
776
  dataset._session, Session
767
777
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -826,10 +836,8 @@ class HistGradientBoostingRegressor(BaseTransformer):
826
836
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
827
837
 
828
838
  if isinstance(dataset, DataFrame):
829
- self._deps = self._batch_inference_validate_snowpark(
830
- dataset=dataset,
831
- inference_method=inference_method,
832
- )
839
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
840
+ self._deps = self._get_dependencies()
833
841
  assert isinstance(
834
842
  dataset._session, Session
835
843
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -891,10 +899,8 @@ class HistGradientBoostingRegressor(BaseTransformer):
891
899
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
892
900
 
893
901
  if isinstance(dataset, DataFrame):
894
- self._deps = self._batch_inference_validate_snowpark(
895
- dataset=dataset,
896
- inference_method=inference_method,
897
- )
902
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
903
+ self._deps = self._get_dependencies()
898
904
  assert isinstance(
899
905
  dataset._session, Session
900
906
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -960,10 +966,8 @@ class HistGradientBoostingRegressor(BaseTransformer):
960
966
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
961
967
 
962
968
  if isinstance(dataset, DataFrame):
963
- self._deps = self._batch_inference_validate_snowpark(
964
- dataset=dataset,
965
- inference_method=inference_method,
966
- )
969
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
970
+ self._deps = self._get_dependencies()
967
971
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
968
972
  transform_kwargs = dict(
969
973
  session=dataset._session,
@@ -1027,17 +1031,15 @@ class HistGradientBoostingRegressor(BaseTransformer):
1027
1031
  transform_kwargs: ScoreKwargsTypedDict = dict()
1028
1032
 
1029
1033
  if isinstance(dataset, DataFrame):
1030
- self._deps = self._batch_inference_validate_snowpark(
1031
- dataset=dataset,
1032
- inference_method="score",
1033
- )
1034
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
1035
+ self._deps = self._get_dependencies()
1034
1036
  selected_cols = self._get_active_columns()
1035
1037
  if len(selected_cols) > 0:
1036
1038
  dataset = dataset.select(selected_cols)
1037
1039
  assert isinstance(dataset._session, Session) # keep mypy happy
1038
1040
  transform_kwargs = dict(
1039
1041
  session=dataset._session,
1040
- dependencies=["snowflake-snowpark-python"] + self._deps,
1042
+ dependencies=self._deps,
1041
1043
  score_sproc_imports=['sklearn'],
1042
1044
  )
1043
1045
  elif isinstance(dataset, pd.DataFrame):
@@ -1102,11 +1104,8 @@ class HistGradientBoostingRegressor(BaseTransformer):
1102
1104
 
1103
1105
  if isinstance(dataset, DataFrame):
1104
1106
 
1105
- self._deps = self._batch_inference_validate_snowpark(
1106
- dataset=dataset,
1107
- inference_method=inference_method,
1108
-
1109
- )
1107
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1108
+ self._deps = self._get_dependencies()
1110
1109
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1111
1110
  transform_kwargs = dict(
1112
1111
  session = dataset._session,