snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (218)
  1. snowflake/ml/_internal/env_utils.py +72 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
  7. snowflake/ml/_internal/telemetry.py +1 -0
  8. snowflake/ml/_internal/utils/identifier.py +1 -1
  9. snowflake/ml/_internal/utils/sql_identifier.py +14 -1
  10. snowflake/ml/dataset/__init__.py +11 -0
  11. snowflake/ml/dataset/dataset.py +455 -129
  12. snowflake/ml/dataset/dataset_factory.py +53 -0
  13. snowflake/ml/dataset/dataset_metadata.py +103 -0
  14. snowflake/ml/dataset/dataset_reader.py +199 -0
  15. snowflake/ml/feature_store/__init__.py +6 -0
  16. snowflake/ml/feature_store/access_manager.py +279 -0
  17. snowflake/ml/feature_store/feature_store.py +544 -358
  18. snowflake/ml/feature_store/feature_view.py +55 -16
  19. snowflake/ml/fileset/embedded_stage_fs.py +149 -0
  20. snowflake/ml/fileset/sfcfs.py +0 -4
  21. snowflake/ml/fileset/snowfs.py +160 -0
  22. snowflake/ml/fileset/stage_fs.py +25 -10
  23. snowflake/ml/model/__init__.py +2 -2
  24. snowflake/ml/model/_api.py +16 -1
  25. snowflake/ml/model/_client/model/model_impl.py +65 -31
  26. snowflake/ml/model/_client/model/model_version_impl.py +159 -2
  27. snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
  28. snowflake/ml/model/_client/ops/model_ops.py +268 -83
  29. snowflake/ml/model/_client/sql/_base.py +34 -0
  30. snowflake/ml/model/_client/sql/model.py +42 -47
  31. snowflake/ml/model/_client/sql/model_version.py +164 -39
  32. snowflake/ml/model/_client/sql/stage.py +6 -32
  33. snowflake/ml/model/_client/sql/tag.py +32 -56
  34. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  35. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  36. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  37. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  38. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  39. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  40. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  41. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  42. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  43. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
  44. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  45. snowflake/ml/model/_packager/model_packager.py +0 -3
  46. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  47. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  48. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  49. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  50. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
  51. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  52. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
  53. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  54. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  55. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  56. snowflake/ml/modeling/cluster/birch.py +53 -52
  57. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  58. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  59. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  60. snowflake/ml/modeling/cluster/k_means.py +53 -52
  61. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  62. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  63. snowflake/ml/modeling/cluster/optics.py +51 -52
  64. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  65. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  66. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  67. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  68. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  69. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  70. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  71. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  72. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  73. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  74. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  75. snowflake/ml/modeling/covariance/oas.py +51 -52
  76. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  77. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  78. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  79. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  80. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  81. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  82. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  83. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  84. snowflake/ml/modeling/decomposition/pca.py +53 -52
  85. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  86. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  87. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  88. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  89. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  92. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  93. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  94. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  95. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  96. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  97. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  98. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  99. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  100. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  101. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  102. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  103. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  104. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  105. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  106. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  107. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  108. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  109. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  110. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  111. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  112. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  113. snowflake/ml/modeling/framework/base.py +64 -36
  114. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  115. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  116. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  117. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  118. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  119. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  120. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  121. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  122. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  123. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  124. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  125. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  126. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  127. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  128. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  129. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  130. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  131. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  132. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  133. snowflake/ml/modeling/linear_model/lars.py +51 -52
  134. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  135. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  136. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  137. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  138. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  139. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  140. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  141. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  142. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  143. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  144. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  146. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  147. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  148. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  149. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  151. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  152. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  153. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  154. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  155. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  156. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  157. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  158. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  159. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  160. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  161. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  162. snowflake/ml/modeling/manifold/isomap.py +53 -52
  163. snowflake/ml/modeling/manifold/mds.py +53 -52
  164. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  165. snowflake/ml/modeling/manifold/tsne.py +53 -52
  166. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  167. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  168. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  169. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  170. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  171. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  172. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  173. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  174. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  175. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  176. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  177. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  178. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  179. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  180. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  181. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  182. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  183. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  184. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  185. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  186. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  187. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  188. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  189. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  190. snowflake/ml/modeling/pipeline/pipeline.py +538 -36
  191. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  192. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  193. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  194. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  195. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  196. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  197. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  198. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  199. snowflake/ml/modeling/svm/svc.py +51 -52
  200. snowflake/ml/modeling/svm/svr.py +51 -52
  201. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  202. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  203. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  204. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  205. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  206. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  207. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  208. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  209. snowflake/ml/registry/_manager/model_manager.py +36 -7
  210. snowflake/ml/registry/model_registry.py +3 -149
  211. snowflake/ml/version.py +1 -1
  212. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
  213. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
  214. snowflake/ml/registry/_artifact_manager.py +0 -156
  215. snowflake/ml/registry/artifact.py +0 -46
  216. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
  217. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
  218. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -59,12 +59,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "xgboost".replace("sklearn.", "")
59
59
 
60
60
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
61
61
 
62
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
63
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
64
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
65
- return check
66
-
67
-
68
62
  class XGBRFClassifier(BaseTransformer):
69
63
  r"""scikit-learn API for XGBoost random forest classification
70
64
  For more details on this class, see [xgboost.XGBRFClassifier]
@@ -487,20 +481,17 @@ class XGBRFClassifier(BaseTransformer):
487
481
  self,
488
482
  dataset: DataFrame,
489
483
  inference_method: str,
490
- ) -> List[str]:
491
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
492
- return the available package that exists in the snowflake anaconda channel
484
+ ) -> None:
485
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
493
486
 
494
487
  Args:
495
488
  dataset: snowpark dataframe
496
489
  inference_method: the inference method such as predict, score...
497
-
490
+
498
491
  Raises:
499
492
  SnowflakeMLException: If the estimator is not fitted, raise error
500
493
  SnowflakeMLException: If the session is None, raise error
501
494
 
502
- Returns:
503
- A list of available package that exists in the snowflake anaconda channel
504
495
  """
505
496
  if not self._is_fitted:
506
497
  raise exceptions.SnowflakeMLException(
@@ -518,9 +509,7 @@ class XGBRFClassifier(BaseTransformer):
518
509
  "Session must not specified for snowpark dataset."
519
510
  ),
520
511
  )
521
- # Validate that key package version in user workspace are supported in snowflake conda channel
522
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
523
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
512
+
524
513
 
525
514
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
526
515
  @telemetry.send_api_usage_telemetry(
@@ -568,7 +557,8 @@ class XGBRFClassifier(BaseTransformer):
568
557
 
569
558
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
570
559
 
571
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
560
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
561
+ self._deps = self._get_dependencies()
572
562
  assert isinstance(
573
563
  dataset._session, Session
574
564
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -651,10 +641,8 @@ class XGBRFClassifier(BaseTransformer):
651
641
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
652
642
  expected_dtype = convert_sp_to_sf_type(output_types[0])
653
643
 
654
- self._deps = self._batch_inference_validate_snowpark(
655
- dataset=dataset,
656
- inference_method=inference_method,
657
- )
644
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
645
+ self._deps = self._get_dependencies()
658
646
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
659
647
 
660
648
  transform_kwargs = dict(
@@ -721,16 +709,40 @@ class XGBRFClassifier(BaseTransformer):
721
709
  self._is_fitted = True
722
710
  return output_result
723
711
 
712
+
713
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
714
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
715
+ """ Method not supported for this class.
724
716
 
725
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
726
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
727
- """
717
+
718
+ Raises:
719
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
720
+
721
+ Args:
722
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
723
+ Snowpark or Pandas DataFrame.
724
+ output_cols_prefix: Prefix for the response columns
728
725
  Returns:
729
726
  Transformed dataset.
730
727
  """
731
- self.fit(dataset)
732
- assert self._sklearn_object is not None
733
- return self._sklearn_object.embedding_
728
+ self._infer_input_output_cols(dataset)
729
+ super()._check_dataset_type(dataset)
730
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
731
+ estimator=self._sklearn_object,
732
+ dataset=dataset,
733
+ input_cols=self.input_cols,
734
+ label_cols=self.label_cols,
735
+ sample_weight_col=self.sample_weight_col,
736
+ autogenerated=self._autogenerated,
737
+ subproject=_SUBPROJECT,
738
+ )
739
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
740
+ drop_input_cols=self._drop_input_cols,
741
+ expected_output_cols_list=self.output_cols,
742
+ )
743
+ self._sklearn_object = fitted_estimator
744
+ self._is_fitted = True
745
+ return output_result
734
746
 
735
747
 
736
748
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -823,10 +835,8 @@ class XGBRFClassifier(BaseTransformer):
823
835
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
824
836
 
825
837
  if isinstance(dataset, DataFrame):
826
- self._deps = self._batch_inference_validate_snowpark(
827
- dataset=dataset,
828
- inference_method=inference_method,
829
- )
838
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
839
+ self._deps = self._get_dependencies()
830
840
  assert isinstance(
831
841
  dataset._session, Session
832
842
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -893,10 +903,8 @@ class XGBRFClassifier(BaseTransformer):
893
903
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
894
904
 
895
905
  if isinstance(dataset, DataFrame):
896
- self._deps = self._batch_inference_validate_snowpark(
897
- dataset=dataset,
898
- inference_method=inference_method,
899
- )
906
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
907
+ self._deps = self._get_dependencies()
900
908
  assert isinstance(
901
909
  dataset._session, Session
902
910
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -958,10 +966,8 @@ class XGBRFClassifier(BaseTransformer):
958
966
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
959
967
 
960
968
  if isinstance(dataset, DataFrame):
961
- self._deps = self._batch_inference_validate_snowpark(
962
- dataset=dataset,
963
- inference_method=inference_method,
964
- )
969
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
970
+ self._deps = self._get_dependencies()
965
971
  assert isinstance(
966
972
  dataset._session, Session
967
973
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -1027,10 +1033,8 @@ class XGBRFClassifier(BaseTransformer):
1027
1033
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
1028
1034
 
1029
1035
  if isinstance(dataset, DataFrame):
1030
- self._deps = self._batch_inference_validate_snowpark(
1031
- dataset=dataset,
1032
- inference_method=inference_method,
1033
- )
1036
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1037
+ self._deps = self._get_dependencies()
1034
1038
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1035
1039
  transform_kwargs = dict(
1036
1040
  session=dataset._session,
@@ -1094,17 +1098,15 @@ class XGBRFClassifier(BaseTransformer):
1094
1098
  transform_kwargs: ScoreKwargsTypedDict = dict()
1095
1099
 
1096
1100
  if isinstance(dataset, DataFrame):
1097
- self._deps = self._batch_inference_validate_snowpark(
1098
- dataset=dataset,
1099
- inference_method="score",
1100
- )
1101
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
1102
+ self._deps = self._get_dependencies()
1101
1103
  selected_cols = self._get_active_columns()
1102
1104
  if len(selected_cols) > 0:
1103
1105
  dataset = dataset.select(selected_cols)
1104
1106
  assert isinstance(dataset._session, Session) # keep mypy happy
1105
1107
  transform_kwargs = dict(
1106
1108
  session=dataset._session,
1107
- dependencies=["snowflake-snowpark-python"] + self._deps,
1109
+ dependencies=self._deps,
1108
1110
  score_sproc_imports=['xgboost'],
1109
1111
  )
1110
1112
  elif isinstance(dataset, pd.DataFrame):
@@ -1169,11 +1171,8 @@ class XGBRFClassifier(BaseTransformer):
1169
1171
 
1170
1172
  if isinstance(dataset, DataFrame):
1171
1173
 
1172
- self._deps = self._batch_inference_validate_snowpark(
1173
- dataset=dataset,
1174
- inference_method=inference_method,
1175
-
1176
- )
1174
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1175
+ self._deps = self._get_dependencies()
1177
1176
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1178
1177
  transform_kwargs = dict(
1179
1178
  session = dataset._session,
@@ -59,12 +59,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "xgboost".replace("sklearn.", "")
59
59
 
60
60
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
61
61
 
62
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
63
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
64
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
65
- return check
66
-
67
-
68
62
  class XGBRFRegressor(BaseTransformer):
69
63
  r"""scikit-learn API for XGBoost random forest regression
70
64
  For more details on this class, see [xgboost.XGBRFRegressor]
@@ -487,20 +481,17 @@ class XGBRFRegressor(BaseTransformer):
487
481
  self,
488
482
  dataset: DataFrame,
489
483
  inference_method: str,
490
- ) -> List[str]:
491
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
492
- return the available package that exists in the snowflake anaconda channel
484
+ ) -> None:
485
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
493
486
 
494
487
  Args:
495
488
  dataset: snowpark dataframe
496
489
  inference_method: the inference method such as predict, score...
497
-
490
+
498
491
  Raises:
499
492
  SnowflakeMLException: If the estimator is not fitted, raise error
500
493
  SnowflakeMLException: If the session is None, raise error
501
494
 
502
- Returns:
503
- A list of available package that exists in the snowflake anaconda channel
504
495
  """
505
496
  if not self._is_fitted:
506
497
  raise exceptions.SnowflakeMLException(
@@ -518,9 +509,7 @@ class XGBRFRegressor(BaseTransformer):
518
509
  "Session must not specified for snowpark dataset."
519
510
  ),
520
511
  )
521
- # Validate that key package version in user workspace are supported in snowflake conda channel
522
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
523
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
512
+
524
513
 
525
514
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
526
515
  @telemetry.send_api_usage_telemetry(
@@ -568,7 +557,8 @@ class XGBRFRegressor(BaseTransformer):
568
557
 
569
558
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
570
559
 
571
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
560
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
561
+ self._deps = self._get_dependencies()
572
562
  assert isinstance(
573
563
  dataset._session, Session
574
564
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -651,10 +641,8 @@ class XGBRFRegressor(BaseTransformer):
651
641
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
652
642
  expected_dtype = convert_sp_to_sf_type(output_types[0])
653
643
 
654
- self._deps = self._batch_inference_validate_snowpark(
655
- dataset=dataset,
656
- inference_method=inference_method,
657
- )
644
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
645
+ self._deps = self._get_dependencies()
658
646
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
659
647
 
660
648
  transform_kwargs = dict(
@@ -721,16 +709,40 @@ class XGBRFRegressor(BaseTransformer):
721
709
  self._is_fitted = True
722
710
  return output_result
723
711
 
712
+
713
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
714
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
715
+ """ Method not supported for this class.
724
716
 
725
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
726
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
727
- """
717
+
718
+ Raises:
719
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
720
+
721
+ Args:
722
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
723
+ Snowpark or Pandas DataFrame.
724
+ output_cols_prefix: Prefix for the response columns
728
725
  Returns:
729
726
  Transformed dataset.
730
727
  """
731
- self.fit(dataset)
732
- assert self._sklearn_object is not None
733
- return self._sklearn_object.embedding_
728
+ self._infer_input_output_cols(dataset)
729
+ super()._check_dataset_type(dataset)
730
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
731
+ estimator=self._sklearn_object,
732
+ dataset=dataset,
733
+ input_cols=self.input_cols,
734
+ label_cols=self.label_cols,
735
+ sample_weight_col=self.sample_weight_col,
736
+ autogenerated=self._autogenerated,
737
+ subproject=_SUBPROJECT,
738
+ )
739
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
740
+ drop_input_cols=self._drop_input_cols,
741
+ expected_output_cols_list=self.output_cols,
742
+ )
743
+ self._sklearn_object = fitted_estimator
744
+ self._is_fitted = True
745
+ return output_result
734
746
 
735
747
 
736
748
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -821,10 +833,8 @@ class XGBRFRegressor(BaseTransformer):
821
833
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
822
834
 
823
835
  if isinstance(dataset, DataFrame):
824
- self._deps = self._batch_inference_validate_snowpark(
825
- dataset=dataset,
826
- inference_method=inference_method,
827
- )
836
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
837
+ self._deps = self._get_dependencies()
828
838
  assert isinstance(
829
839
  dataset._session, Session
830
840
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -889,10 +899,8 @@ class XGBRFRegressor(BaseTransformer):
889
899
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
890
900
 
891
901
  if isinstance(dataset, DataFrame):
892
- self._deps = self._batch_inference_validate_snowpark(
893
- dataset=dataset,
894
- inference_method=inference_method,
895
- )
902
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
903
+ self._deps = self._get_dependencies()
896
904
  assert isinstance(
897
905
  dataset._session, Session
898
906
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -954,10 +962,8 @@ class XGBRFRegressor(BaseTransformer):
954
962
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
955
963
 
956
964
  if isinstance(dataset, DataFrame):
957
- self._deps = self._batch_inference_validate_snowpark(
958
- dataset=dataset,
959
- inference_method=inference_method,
960
- )
965
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
966
+ self._deps = self._get_dependencies()
961
967
  assert isinstance(
962
968
  dataset._session, Session
963
969
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -1023,10 +1029,8 @@ class XGBRFRegressor(BaseTransformer):
1023
1029
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
1024
1030
 
1025
1031
  if isinstance(dataset, DataFrame):
1026
- self._deps = self._batch_inference_validate_snowpark(
1027
- dataset=dataset,
1028
- inference_method=inference_method,
1029
- )
1032
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1033
+ self._deps = self._get_dependencies()
1030
1034
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1031
1035
  transform_kwargs = dict(
1032
1036
  session=dataset._session,
@@ -1090,17 +1094,15 @@ class XGBRFRegressor(BaseTransformer):
1090
1094
  transform_kwargs: ScoreKwargsTypedDict = dict()
1091
1095
 
1092
1096
  if isinstance(dataset, DataFrame):
1093
- self._deps = self._batch_inference_validate_snowpark(
1094
- dataset=dataset,
1095
- inference_method="score",
1096
- )
1097
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
1098
+ self._deps = self._get_dependencies()
1097
1099
  selected_cols = self._get_active_columns()
1098
1100
  if len(selected_cols) > 0:
1099
1101
  dataset = dataset.select(selected_cols)
1100
1102
  assert isinstance(dataset._session, Session) # keep mypy happy
1101
1103
  transform_kwargs = dict(
1102
1104
  session=dataset._session,
1103
- dependencies=["snowflake-snowpark-python"] + self._deps,
1105
+ dependencies=self._deps,
1104
1106
  score_sproc_imports=['xgboost'],
1105
1107
  )
1106
1108
  elif isinstance(dataset, pd.DataFrame):
@@ -1165,11 +1167,8 @@ class XGBRFRegressor(BaseTransformer):
1165
1167
 
1166
1168
  if isinstance(dataset, DataFrame):
1167
1169
 
1168
- self._deps = self._batch_inference_validate_snowpark(
1169
- dataset=dataset,
1170
- inference_method=inference_method,
1171
-
1172
- )
1170
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1171
+ self._deps = self._get_dependencies()
1173
1172
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1174
1173
  transform_kwargs = dict(
1175
1174
  session = dataset._session,
@@ -48,20 +48,29 @@ class ModelManager:
48
48
  options: Optional[model_types.ModelSaveOption] = None,
49
49
  statement_params: Optional[Dict[str, Any]] = None,
50
50
  ) -> model_version_impl.ModelVersion:
51
- model_name_id = sql_identifier.SqlIdentifier(model_name)
51
+ database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
52
52
 
53
53
  if not version_name:
54
54
  version_name = self._hrid_generator.generate()[1]
55
55
  version_name_id = sql_identifier.SqlIdentifier(version_name)
56
56
 
57
57
  if self._model_ops.validate_existence(
58
- model_name=model_name_id, statement_params=statement_params
58
+ database_name=database_name_id,
59
+ schema_name=schema_name_id,
60
+ model_name=model_name_id,
61
+ statement_params=statement_params,
59
62
  ) and self._model_ops.validate_existence(
60
- model_name=model_name_id, version_name=version_name_id, statement_params=statement_params
63
+ database_name=database_name_id,
64
+ schema_name=schema_name_id,
65
+ model_name=model_name_id,
66
+ version_name=version_name_id,
67
+ statement_params=statement_params,
61
68
  ):
62
69
  raise ValueError(f"Model {model_name} version {version_name} already existed.")
63
70
 
64
71
  stage_path = self._model_ops.prepare_model_stage_path(
72
+ database_name=database_name_id,
73
+ schema_name=schema_name_id,
65
74
  statement_params=statement_params,
66
75
  )
67
76
 
@@ -85,13 +94,19 @@ class ModelManager:
85
94
 
86
95
  self._model_ops.create_from_stage(
87
96
  composed_model=mc,
97
+ database_name=database_name_id,
98
+ schema_name=schema_name_id,
88
99
  model_name=model_name_id,
89
100
  version_name=version_name_id,
90
101
  statement_params=statement_params,
91
102
  )
92
103
 
93
104
  mv = model_version_impl.ModelVersion._ref(
94
- self._model_ops,
105
+ model_ops.ModelOperator(
106
+ self._model_ops._session,
107
+ database_name=database_name_id or self._database_name,
108
+ schema_name=schema_name_id or self._schema_name,
109
+ ),
95
110
  model_name=model_name_id,
96
111
  version_name=version_name_id,
97
112
  )
@@ -102,6 +117,8 @@ class ModelManager:
102
117
  if metrics:
103
118
  self._model_ops._metadata_ops.save(
104
119
  metadata_ops.ModelVersionMetadataSchema(metrics=metrics),
120
+ database_name=database_name_id,
121
+ schema_name=schema_name_id,
105
122
  model_name=model_name_id,
106
123
  version_name=version_name_id,
107
124
  statement_params=statement_params,
@@ -115,13 +132,19 @@ class ModelManager:
115
132
  *,
116
133
  statement_params: Optional[Dict[str, Any]] = None,
117
134
  ) -> model_impl.Model:
118
- model_name_id = sql_identifier.SqlIdentifier(model_name)
135
+ database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
119
136
  if self._model_ops.validate_existence(
137
+ database_name=database_name_id,
138
+ schema_name=schema_name_id,
120
139
  model_name=model_name_id,
121
140
  statement_params=statement_params,
122
141
  ):
123
142
  return model_impl.Model._ref(
124
- self._model_ops,
143
+ model_ops.ModelOperator(
144
+ self._model_ops._session,
145
+ database_name=database_name_id or self._database_name,
146
+ schema_name=schema_name_id or self._schema_name,
147
+ ),
125
148
  model_name=model_name_id,
126
149
  )
127
150
  else:
@@ -133,6 +156,8 @@ class ModelManager:
133
156
  statement_params: Optional[Dict[str, Any]] = None,
134
157
  ) -> List[model_impl.Model]:
135
158
  model_names = self._model_ops.list_models_or_versions(
159
+ database_name=None,
160
+ schema_name=None,
136
161
  statement_params=statement_params,
137
162
  )
138
163
  return [
@@ -149,6 +174,8 @@ class ModelManager:
149
174
  statement_params: Optional[Dict[str, Any]] = None,
150
175
  ) -> pd.DataFrame:
151
176
  rows = self._model_ops.show_models_or_versions(
177
+ database_name=None,
178
+ schema_name=None,
152
179
  statement_params=statement_params,
153
180
  )
154
181
  return pd.DataFrame([row.as_dict() for row in rows])
@@ -159,9 +186,11 @@ class ModelManager:
159
186
  *,
160
187
  statement_params: Optional[Dict[str, Any]] = None,
161
188
  ) -> None:
162
- model_name_id = sql_identifier.SqlIdentifier(model_name)
189
+ database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
163
190
 
164
191
  self._model_ops.delete_model_or_version(
192
+ database_name=database_name_id,
193
+ schema_name=schema_name_id,
165
194
  model_name=model_name_id,
166
195
  statement_params=statement_params,
167
196
  )