snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206) hide show
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -59,12 +59,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "xgboost".replace("sklearn.", "")
59
59
 
60
60
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
61
61
 
62
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
63
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
64
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
65
- return check
66
-
67
-
68
62
  class XGBRFClassifier(BaseTransformer):
69
63
  r"""scikit-learn API for XGBoost random forest classification
70
64
  For more details on this class, see [xgboost.XGBRFClassifier]
@@ -487,20 +481,17 @@ class XGBRFClassifier(BaseTransformer):
487
481
  self,
488
482
  dataset: DataFrame,
489
483
  inference_method: str,
490
- ) -> List[str]:
491
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
492
- return the available package that exists in the snowflake anaconda channel
484
+ ) -> None:
485
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
493
486
 
494
487
  Args:
495
488
  dataset: snowpark dataframe
496
489
  inference_method: the inference method such as predict, score...
497
-
490
+
498
491
  Raises:
499
492
  SnowflakeMLException: If the estimator is not fitted, raise error
500
493
  SnowflakeMLException: If the session is None, raise error
501
494
 
502
- Returns:
503
- A list of available package that exists in the snowflake anaconda channel
504
495
  """
505
496
  if not self._is_fitted:
506
497
  raise exceptions.SnowflakeMLException(
@@ -518,9 +509,7 @@ class XGBRFClassifier(BaseTransformer):
518
509
  "Session must not specified for snowpark dataset."
519
510
  ),
520
511
  )
521
- # Validate that key package version in user workspace are supported in snowflake conda channel
522
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
523
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
512
+
524
513
 
525
514
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
526
515
  @telemetry.send_api_usage_telemetry(
@@ -568,7 +557,8 @@ class XGBRFClassifier(BaseTransformer):
568
557
 
569
558
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
570
559
 
571
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
560
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
561
+ self._deps = self._get_dependencies()
572
562
  assert isinstance(
573
563
  dataset._session, Session
574
564
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -651,10 +641,8 @@ class XGBRFClassifier(BaseTransformer):
651
641
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
652
642
  expected_dtype = convert_sp_to_sf_type(output_types[0])
653
643
 
654
- self._deps = self._batch_inference_validate_snowpark(
655
- dataset=dataset,
656
- inference_method=inference_method,
657
- )
644
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
645
+ self._deps = self._get_dependencies()
658
646
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
659
647
 
660
648
  transform_kwargs = dict(
@@ -721,16 +709,40 @@ class XGBRFClassifier(BaseTransformer):
721
709
  self._is_fitted = True
722
710
  return output_result
723
711
 
712
+
713
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
714
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
715
+ """ Method not supported for this class.
724
716
 
725
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
726
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
727
- """
717
+
718
+ Raises:
719
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
720
+
721
+ Args:
722
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
723
+ Snowpark or Pandas DataFrame.
724
+ output_cols_prefix: Prefix for the response columns
728
725
  Returns:
729
726
  Transformed dataset.
730
727
  """
731
- self.fit(dataset)
732
- assert self._sklearn_object is not None
733
- return self._sklearn_object.embedding_
728
+ self._infer_input_output_cols(dataset)
729
+ super()._check_dataset_type(dataset)
730
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
731
+ estimator=self._sklearn_object,
732
+ dataset=dataset,
733
+ input_cols=self.input_cols,
734
+ label_cols=self.label_cols,
735
+ sample_weight_col=self.sample_weight_col,
736
+ autogenerated=self._autogenerated,
737
+ subproject=_SUBPROJECT,
738
+ )
739
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
740
+ drop_input_cols=self._drop_input_cols,
741
+ expected_output_cols_list=self.output_cols,
742
+ )
743
+ self._sklearn_object = fitted_estimator
744
+ self._is_fitted = True
745
+ return output_result
734
746
 
735
747
 
736
748
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -823,10 +835,8 @@ class XGBRFClassifier(BaseTransformer):
823
835
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
824
836
 
825
837
  if isinstance(dataset, DataFrame):
826
- self._deps = self._batch_inference_validate_snowpark(
827
- dataset=dataset,
828
- inference_method=inference_method,
829
- )
838
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
839
+ self._deps = self._get_dependencies()
830
840
  assert isinstance(
831
841
  dataset._session, Session
832
842
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -893,10 +903,8 @@ class XGBRFClassifier(BaseTransformer):
893
903
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
894
904
 
895
905
  if isinstance(dataset, DataFrame):
896
- self._deps = self._batch_inference_validate_snowpark(
897
- dataset=dataset,
898
- inference_method=inference_method,
899
- )
906
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
907
+ self._deps = self._get_dependencies()
900
908
  assert isinstance(
901
909
  dataset._session, Session
902
910
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -958,10 +966,8 @@ class XGBRFClassifier(BaseTransformer):
958
966
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
959
967
 
960
968
  if isinstance(dataset, DataFrame):
961
- self._deps = self._batch_inference_validate_snowpark(
962
- dataset=dataset,
963
- inference_method=inference_method,
964
- )
969
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
970
+ self._deps = self._get_dependencies()
965
971
  assert isinstance(
966
972
  dataset._session, Session
967
973
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -1027,10 +1033,8 @@ class XGBRFClassifier(BaseTransformer):
1027
1033
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
1028
1034
 
1029
1035
  if isinstance(dataset, DataFrame):
1030
- self._deps = self._batch_inference_validate_snowpark(
1031
- dataset=dataset,
1032
- inference_method=inference_method,
1033
- )
1036
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1037
+ self._deps = self._get_dependencies()
1034
1038
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1035
1039
  transform_kwargs = dict(
1036
1040
  session=dataset._session,
@@ -1094,17 +1098,15 @@ class XGBRFClassifier(BaseTransformer):
1094
1098
  transform_kwargs: ScoreKwargsTypedDict = dict()
1095
1099
 
1096
1100
  if isinstance(dataset, DataFrame):
1097
- self._deps = self._batch_inference_validate_snowpark(
1098
- dataset=dataset,
1099
- inference_method="score",
1100
- )
1101
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
1102
+ self._deps = self._get_dependencies()
1101
1103
  selected_cols = self._get_active_columns()
1102
1104
  if len(selected_cols) > 0:
1103
1105
  dataset = dataset.select(selected_cols)
1104
1106
  assert isinstance(dataset._session, Session) # keep mypy happy
1105
1107
  transform_kwargs = dict(
1106
1108
  session=dataset._session,
1107
- dependencies=["snowflake-snowpark-python"] + self._deps,
1109
+ dependencies=self._deps,
1108
1110
  score_sproc_imports=['xgboost'],
1109
1111
  )
1110
1112
  elif isinstance(dataset, pd.DataFrame):
@@ -1169,11 +1171,8 @@ class XGBRFClassifier(BaseTransformer):
1169
1171
 
1170
1172
  if isinstance(dataset, DataFrame):
1171
1173
 
1172
- self._deps = self._batch_inference_validate_snowpark(
1173
- dataset=dataset,
1174
- inference_method=inference_method,
1175
-
1176
- )
1174
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1175
+ self._deps = self._get_dependencies()
1177
1176
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1178
1177
  transform_kwargs = dict(
1179
1178
  session = dataset._session,
@@ -59,12 +59,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "xgboost".replace("sklearn.", "")
59
59
 
60
60
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
61
61
 
62
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
63
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
64
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
65
- return check
66
-
67
-
68
62
  class XGBRFRegressor(BaseTransformer):
69
63
  r"""scikit-learn API for XGBoost random forest regression
70
64
  For more details on this class, see [xgboost.XGBRFRegressor]
@@ -487,20 +481,17 @@ class XGBRFRegressor(BaseTransformer):
487
481
  self,
488
482
  dataset: DataFrame,
489
483
  inference_method: str,
490
- ) -> List[str]:
491
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
492
- return the available package that exists in the snowflake anaconda channel
484
+ ) -> None:
485
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
493
486
 
494
487
  Args:
495
488
  dataset: snowpark dataframe
496
489
  inference_method: the inference method such as predict, score...
497
-
490
+
498
491
  Raises:
499
492
  SnowflakeMLException: If the estimator is not fitted, raise error
500
493
  SnowflakeMLException: If the session is None, raise error
501
494
 
502
- Returns:
503
- A list of available package that exists in the snowflake anaconda channel
504
495
  """
505
496
  if not self._is_fitted:
506
497
  raise exceptions.SnowflakeMLException(
@@ -518,9 +509,7 @@ class XGBRFRegressor(BaseTransformer):
518
509
  "Session must not specified for snowpark dataset."
519
510
  ),
520
511
  )
521
- # Validate that key package version in user workspace are supported in snowflake conda channel
522
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
523
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
512
+
524
513
 
525
514
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
526
515
  @telemetry.send_api_usage_telemetry(
@@ -568,7 +557,8 @@ class XGBRFRegressor(BaseTransformer):
568
557
 
569
558
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
570
559
 
571
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
560
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
561
+ self._deps = self._get_dependencies()
572
562
  assert isinstance(
573
563
  dataset._session, Session
574
564
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -651,10 +641,8 @@ class XGBRFRegressor(BaseTransformer):
651
641
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
652
642
  expected_dtype = convert_sp_to_sf_type(output_types[0])
653
643
 
654
- self._deps = self._batch_inference_validate_snowpark(
655
- dataset=dataset,
656
- inference_method=inference_method,
657
- )
644
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
645
+ self._deps = self._get_dependencies()
658
646
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
659
647
 
660
648
  transform_kwargs = dict(
@@ -721,16 +709,40 @@ class XGBRFRegressor(BaseTransformer):
721
709
  self._is_fitted = True
722
710
  return output_result
723
711
 
712
+
713
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
714
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
715
+ """ Method not supported for this class.
724
716
 
725
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
726
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
727
- """
717
+
718
+ Raises:
719
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
720
+
721
+ Args:
722
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
723
+ Snowpark or Pandas DataFrame.
724
+ output_cols_prefix: Prefix for the response columns
728
725
  Returns:
729
726
  Transformed dataset.
730
727
  """
731
- self.fit(dataset)
732
- assert self._sklearn_object is not None
733
- return self._sklearn_object.embedding_
728
+ self._infer_input_output_cols(dataset)
729
+ super()._check_dataset_type(dataset)
730
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
731
+ estimator=self._sklearn_object,
732
+ dataset=dataset,
733
+ input_cols=self.input_cols,
734
+ label_cols=self.label_cols,
735
+ sample_weight_col=self.sample_weight_col,
736
+ autogenerated=self._autogenerated,
737
+ subproject=_SUBPROJECT,
738
+ )
739
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
740
+ drop_input_cols=self._drop_input_cols,
741
+ expected_output_cols_list=self.output_cols,
742
+ )
743
+ self._sklearn_object = fitted_estimator
744
+ self._is_fitted = True
745
+ return output_result
734
746
 
735
747
 
736
748
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -821,10 +833,8 @@ class XGBRFRegressor(BaseTransformer):
821
833
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
822
834
 
823
835
  if isinstance(dataset, DataFrame):
824
- self._deps = self._batch_inference_validate_snowpark(
825
- dataset=dataset,
826
- inference_method=inference_method,
827
- )
836
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
837
+ self._deps = self._get_dependencies()
828
838
  assert isinstance(
829
839
  dataset._session, Session
830
840
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -889,10 +899,8 @@ class XGBRFRegressor(BaseTransformer):
889
899
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
890
900
 
891
901
  if isinstance(dataset, DataFrame):
892
- self._deps = self._batch_inference_validate_snowpark(
893
- dataset=dataset,
894
- inference_method=inference_method,
895
- )
902
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
903
+ self._deps = self._get_dependencies()
896
904
  assert isinstance(
897
905
  dataset._session, Session
898
906
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -954,10 +962,8 @@ class XGBRFRegressor(BaseTransformer):
954
962
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
955
963
 
956
964
  if isinstance(dataset, DataFrame):
957
- self._deps = self._batch_inference_validate_snowpark(
958
- dataset=dataset,
959
- inference_method=inference_method,
960
- )
965
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
966
+ self._deps = self._get_dependencies()
961
967
  assert isinstance(
962
968
  dataset._session, Session
963
969
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -1023,10 +1029,8 @@ class XGBRFRegressor(BaseTransformer):
1023
1029
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
1024
1030
 
1025
1031
  if isinstance(dataset, DataFrame):
1026
- self._deps = self._batch_inference_validate_snowpark(
1027
- dataset=dataset,
1028
- inference_method=inference_method,
1029
- )
1032
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1033
+ self._deps = self._get_dependencies()
1030
1034
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1031
1035
  transform_kwargs = dict(
1032
1036
  session=dataset._session,
@@ -1090,17 +1094,15 @@ class XGBRFRegressor(BaseTransformer):
1090
1094
  transform_kwargs: ScoreKwargsTypedDict = dict()
1091
1095
 
1092
1096
  if isinstance(dataset, DataFrame):
1093
- self._deps = self._batch_inference_validate_snowpark(
1094
- dataset=dataset,
1095
- inference_method="score",
1096
- )
1097
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
1098
+ self._deps = self._get_dependencies()
1097
1099
  selected_cols = self._get_active_columns()
1098
1100
  if len(selected_cols) > 0:
1099
1101
  dataset = dataset.select(selected_cols)
1100
1102
  assert isinstance(dataset._session, Session) # keep mypy happy
1101
1103
  transform_kwargs = dict(
1102
1104
  session=dataset._session,
1103
- dependencies=["snowflake-snowpark-python"] + self._deps,
1105
+ dependencies=self._deps,
1104
1106
  score_sproc_imports=['xgboost'],
1105
1107
  )
1106
1108
  elif isinstance(dataset, pd.DataFrame):
@@ -1165,11 +1167,8 @@ class XGBRFRegressor(BaseTransformer):
1165
1167
 
1166
1168
  if isinstance(dataset, DataFrame):
1167
1169
 
1168
- self._deps = self._batch_inference_validate_snowpark(
1169
- dataset=dataset,
1170
- inference_method=inference_method,
1171
-
1172
- )
1170
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1171
+ self._deps = self._get_dependencies()
1173
1172
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1174
1173
  transform_kwargs = dict(
1175
1174
  session = dataset._session,