snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +72 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
- snowflake/ml/_internal/telemetry.py +1 -0
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/sql_identifier.py +14 -1
- snowflake/ml/dataset/__init__.py +11 -0
- snowflake/ml/dataset/dataset.py +455 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +199 -0
- snowflake/ml/feature_store/__init__.py +6 -0
- snowflake/ml/feature_store/access_manager.py +279 -0
- snowflake/ml/feature_store/feature_store.py +544 -358
- snowflake/ml/feature_store/feature_view.py +55 -16
- snowflake/ml/fileset/embedded_stage_fs.py +149 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +160 -0
- snowflake/ml/fileset/stage_fs.py +25 -10
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +65 -31
- snowflake/ml/model/_client/model/model_version_impl.py +159 -2
- snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
- snowflake/ml/model/_client/ops/model_ops.py +268 -83
- snowflake/ml/model/_client/sql/_base.py +34 -0
- snowflake/ml/model/_client/sql/model.py +42 -47
- snowflake/ml/model/_client/sql/model_version.py +164 -39
- snowflake/ml/model/_client/sql/stage.py +6 -32
- snowflake/ml/model/_client/sql/tag.py +32 -56
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +64 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +538 -36
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/_manager/model_manager.py +36 -7
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -59,12 +59,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "xgboost".replace("sklearn.", "")
|
|
59
59
|
|
60
60
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
61
61
|
|
62
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
63
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
64
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
65
|
-
return check
|
66
|
-
|
67
|
-
|
68
62
|
class XGBRFClassifier(BaseTransformer):
|
69
63
|
r"""scikit-learn API for XGBoost random forest classification
|
70
64
|
For more details on this class, see [xgboost.XGBRFClassifier]
|
@@ -487,20 +481,17 @@ class XGBRFClassifier(BaseTransformer):
|
|
487
481
|
self,
|
488
482
|
dataset: DataFrame,
|
489
483
|
inference_method: str,
|
490
|
-
) ->
|
491
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
492
|
-
return the available package that exists in the snowflake anaconda channel
|
484
|
+
) -> None:
|
485
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
493
486
|
|
494
487
|
Args:
|
495
488
|
dataset: snowpark dataframe
|
496
489
|
inference_method: the inference method such as predict, score...
|
497
|
-
|
490
|
+
|
498
491
|
Raises:
|
499
492
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
500
493
|
SnowflakeMLException: If the session is None, raise error
|
501
494
|
|
502
|
-
Returns:
|
503
|
-
A list of available package that exists in the snowflake anaconda channel
|
504
495
|
"""
|
505
496
|
if not self._is_fitted:
|
506
497
|
raise exceptions.SnowflakeMLException(
|
@@ -518,9 +509,7 @@ class XGBRFClassifier(BaseTransformer):
|
|
518
509
|
"Session must not specified for snowpark dataset."
|
519
510
|
),
|
520
511
|
)
|
521
|
-
|
522
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
523
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
512
|
+
|
524
513
|
|
525
514
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
526
515
|
@telemetry.send_api_usage_telemetry(
|
@@ -568,7 +557,8 @@ class XGBRFClassifier(BaseTransformer):
|
|
568
557
|
|
569
558
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
570
559
|
|
571
|
-
self.
|
560
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
561
|
+
self._deps = self._get_dependencies()
|
572
562
|
assert isinstance(
|
573
563
|
dataset._session, Session
|
574
564
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -651,10 +641,8 @@ class XGBRFClassifier(BaseTransformer):
|
|
651
641
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
652
642
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
653
643
|
|
654
|
-
self.
|
655
|
-
|
656
|
-
inference_method=inference_method,
|
657
|
-
)
|
644
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
645
|
+
self._deps = self._get_dependencies()
|
658
646
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
659
647
|
|
660
648
|
transform_kwargs = dict(
|
@@ -721,16 +709,40 @@ class XGBRFClassifier(BaseTransformer):
|
|
721
709
|
self._is_fitted = True
|
722
710
|
return output_result
|
723
711
|
|
712
|
+
|
713
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
714
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
715
|
+
""" Method not supported for this class.
|
724
716
|
|
725
|
-
|
726
|
-
|
727
|
-
|
717
|
+
|
718
|
+
Raises:
|
719
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
720
|
+
|
721
|
+
Args:
|
722
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
723
|
+
Snowpark or Pandas DataFrame.
|
724
|
+
output_cols_prefix: Prefix for the response columns
|
728
725
|
Returns:
|
729
726
|
Transformed dataset.
|
730
727
|
"""
|
731
|
-
self.
|
732
|
-
|
733
|
-
|
728
|
+
self._infer_input_output_cols(dataset)
|
729
|
+
super()._check_dataset_type(dataset)
|
730
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
731
|
+
estimator=self._sklearn_object,
|
732
|
+
dataset=dataset,
|
733
|
+
input_cols=self.input_cols,
|
734
|
+
label_cols=self.label_cols,
|
735
|
+
sample_weight_col=self.sample_weight_col,
|
736
|
+
autogenerated=self._autogenerated,
|
737
|
+
subproject=_SUBPROJECT,
|
738
|
+
)
|
739
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
740
|
+
drop_input_cols=self._drop_input_cols,
|
741
|
+
expected_output_cols_list=self.output_cols,
|
742
|
+
)
|
743
|
+
self._sklearn_object = fitted_estimator
|
744
|
+
self._is_fitted = True
|
745
|
+
return output_result
|
734
746
|
|
735
747
|
|
736
748
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -823,10 +835,8 @@ class XGBRFClassifier(BaseTransformer):
|
|
823
835
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
824
836
|
|
825
837
|
if isinstance(dataset, DataFrame):
|
826
|
-
self.
|
827
|
-
|
828
|
-
inference_method=inference_method,
|
829
|
-
)
|
838
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
839
|
+
self._deps = self._get_dependencies()
|
830
840
|
assert isinstance(
|
831
841
|
dataset._session, Session
|
832
842
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -893,10 +903,8 @@ class XGBRFClassifier(BaseTransformer):
|
|
893
903
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
894
904
|
|
895
905
|
if isinstance(dataset, DataFrame):
|
896
|
-
self.
|
897
|
-
|
898
|
-
inference_method=inference_method,
|
899
|
-
)
|
906
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
907
|
+
self._deps = self._get_dependencies()
|
900
908
|
assert isinstance(
|
901
909
|
dataset._session, Session
|
902
910
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -958,10 +966,8 @@ class XGBRFClassifier(BaseTransformer):
|
|
958
966
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
959
967
|
|
960
968
|
if isinstance(dataset, DataFrame):
|
961
|
-
self.
|
962
|
-
|
963
|
-
inference_method=inference_method,
|
964
|
-
)
|
969
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
970
|
+
self._deps = self._get_dependencies()
|
965
971
|
assert isinstance(
|
966
972
|
dataset._session, Session
|
967
973
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -1027,10 +1033,8 @@ class XGBRFClassifier(BaseTransformer):
|
|
1027
1033
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
1028
1034
|
|
1029
1035
|
if isinstance(dataset, DataFrame):
|
1030
|
-
self.
|
1031
|
-
|
1032
|
-
inference_method=inference_method,
|
1033
|
-
)
|
1036
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1037
|
+
self._deps = self._get_dependencies()
|
1034
1038
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1035
1039
|
transform_kwargs = dict(
|
1036
1040
|
session=dataset._session,
|
@@ -1094,17 +1098,15 @@ class XGBRFClassifier(BaseTransformer):
|
|
1094
1098
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
1095
1099
|
|
1096
1100
|
if isinstance(dataset, DataFrame):
|
1097
|
-
self.
|
1098
|
-
|
1099
|
-
inference_method="score",
|
1100
|
-
)
|
1101
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
1102
|
+
self._deps = self._get_dependencies()
|
1101
1103
|
selected_cols = self._get_active_columns()
|
1102
1104
|
if len(selected_cols) > 0:
|
1103
1105
|
dataset = dataset.select(selected_cols)
|
1104
1106
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
1105
1107
|
transform_kwargs = dict(
|
1106
1108
|
session=dataset._session,
|
1107
|
-
dependencies=
|
1109
|
+
dependencies=self._deps,
|
1108
1110
|
score_sproc_imports=['xgboost'],
|
1109
1111
|
)
|
1110
1112
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1169,11 +1171,8 @@ class XGBRFClassifier(BaseTransformer):
|
|
1169
1171
|
|
1170
1172
|
if isinstance(dataset, DataFrame):
|
1171
1173
|
|
1172
|
-
self.
|
1173
|
-
|
1174
|
-
inference_method=inference_method,
|
1175
|
-
|
1176
|
-
)
|
1174
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1175
|
+
self._deps = self._get_dependencies()
|
1177
1176
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1178
1177
|
transform_kwargs = dict(
|
1179
1178
|
session = dataset._session,
|
@@ -59,12 +59,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "xgboost".replace("sklearn.", "")
|
|
59
59
|
|
60
60
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
61
61
|
|
62
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
63
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
64
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
65
|
-
return check
|
66
|
-
|
67
|
-
|
68
62
|
class XGBRFRegressor(BaseTransformer):
|
69
63
|
r"""scikit-learn API for XGBoost random forest regression
|
70
64
|
For more details on this class, see [xgboost.XGBRFRegressor]
|
@@ -487,20 +481,17 @@ class XGBRFRegressor(BaseTransformer):
|
|
487
481
|
self,
|
488
482
|
dataset: DataFrame,
|
489
483
|
inference_method: str,
|
490
|
-
) ->
|
491
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
492
|
-
return the available package that exists in the snowflake anaconda channel
|
484
|
+
) -> None:
|
485
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
493
486
|
|
494
487
|
Args:
|
495
488
|
dataset: snowpark dataframe
|
496
489
|
inference_method: the inference method such as predict, score...
|
497
|
-
|
490
|
+
|
498
491
|
Raises:
|
499
492
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
500
493
|
SnowflakeMLException: If the session is None, raise error
|
501
494
|
|
502
|
-
Returns:
|
503
|
-
A list of available package that exists in the snowflake anaconda channel
|
504
495
|
"""
|
505
496
|
if not self._is_fitted:
|
506
497
|
raise exceptions.SnowflakeMLException(
|
@@ -518,9 +509,7 @@ class XGBRFRegressor(BaseTransformer):
|
|
518
509
|
"Session must not specified for snowpark dataset."
|
519
510
|
),
|
520
511
|
)
|
521
|
-
|
522
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
523
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
512
|
+
|
524
513
|
|
525
514
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
526
515
|
@telemetry.send_api_usage_telemetry(
|
@@ -568,7 +557,8 @@ class XGBRFRegressor(BaseTransformer):
|
|
568
557
|
|
569
558
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
570
559
|
|
571
|
-
self.
|
560
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
561
|
+
self._deps = self._get_dependencies()
|
572
562
|
assert isinstance(
|
573
563
|
dataset._session, Session
|
574
564
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -651,10 +641,8 @@ class XGBRFRegressor(BaseTransformer):
|
|
651
641
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
652
642
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
653
643
|
|
654
|
-
self.
|
655
|
-
|
656
|
-
inference_method=inference_method,
|
657
|
-
)
|
644
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
645
|
+
self._deps = self._get_dependencies()
|
658
646
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
659
647
|
|
660
648
|
transform_kwargs = dict(
|
@@ -721,16 +709,40 @@ class XGBRFRegressor(BaseTransformer):
|
|
721
709
|
self._is_fitted = True
|
722
710
|
return output_result
|
723
711
|
|
712
|
+
|
713
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
714
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
715
|
+
""" Method not supported for this class.
|
724
716
|
|
725
|
-
|
726
|
-
|
727
|
-
|
717
|
+
|
718
|
+
Raises:
|
719
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
720
|
+
|
721
|
+
Args:
|
722
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
723
|
+
Snowpark or Pandas DataFrame.
|
724
|
+
output_cols_prefix: Prefix for the response columns
|
728
725
|
Returns:
|
729
726
|
Transformed dataset.
|
730
727
|
"""
|
731
|
-
self.
|
732
|
-
|
733
|
-
|
728
|
+
self._infer_input_output_cols(dataset)
|
729
|
+
super()._check_dataset_type(dataset)
|
730
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
731
|
+
estimator=self._sklearn_object,
|
732
|
+
dataset=dataset,
|
733
|
+
input_cols=self.input_cols,
|
734
|
+
label_cols=self.label_cols,
|
735
|
+
sample_weight_col=self.sample_weight_col,
|
736
|
+
autogenerated=self._autogenerated,
|
737
|
+
subproject=_SUBPROJECT,
|
738
|
+
)
|
739
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
740
|
+
drop_input_cols=self._drop_input_cols,
|
741
|
+
expected_output_cols_list=self.output_cols,
|
742
|
+
)
|
743
|
+
self._sklearn_object = fitted_estimator
|
744
|
+
self._is_fitted = True
|
745
|
+
return output_result
|
734
746
|
|
735
747
|
|
736
748
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -821,10 +833,8 @@ class XGBRFRegressor(BaseTransformer):
|
|
821
833
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
822
834
|
|
823
835
|
if isinstance(dataset, DataFrame):
|
824
|
-
self.
|
825
|
-
|
826
|
-
inference_method=inference_method,
|
827
|
-
)
|
836
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
837
|
+
self._deps = self._get_dependencies()
|
828
838
|
assert isinstance(
|
829
839
|
dataset._session, Session
|
830
840
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -889,10 +899,8 @@ class XGBRFRegressor(BaseTransformer):
|
|
889
899
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
890
900
|
|
891
901
|
if isinstance(dataset, DataFrame):
|
892
|
-
self.
|
893
|
-
|
894
|
-
inference_method=inference_method,
|
895
|
-
)
|
902
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
903
|
+
self._deps = self._get_dependencies()
|
896
904
|
assert isinstance(
|
897
905
|
dataset._session, Session
|
898
906
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -954,10 +962,8 @@ class XGBRFRegressor(BaseTransformer):
|
|
954
962
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
955
963
|
|
956
964
|
if isinstance(dataset, DataFrame):
|
957
|
-
self.
|
958
|
-
|
959
|
-
inference_method=inference_method,
|
960
|
-
)
|
965
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
966
|
+
self._deps = self._get_dependencies()
|
961
967
|
assert isinstance(
|
962
968
|
dataset._session, Session
|
963
969
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -1023,10 +1029,8 @@ class XGBRFRegressor(BaseTransformer):
|
|
1023
1029
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
1024
1030
|
|
1025
1031
|
if isinstance(dataset, DataFrame):
|
1026
|
-
self.
|
1027
|
-
|
1028
|
-
inference_method=inference_method,
|
1029
|
-
)
|
1032
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1033
|
+
self._deps = self._get_dependencies()
|
1030
1034
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1031
1035
|
transform_kwargs = dict(
|
1032
1036
|
session=dataset._session,
|
@@ -1090,17 +1094,15 @@ class XGBRFRegressor(BaseTransformer):
|
|
1090
1094
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
1091
1095
|
|
1092
1096
|
if isinstance(dataset, DataFrame):
|
1093
|
-
self.
|
1094
|
-
|
1095
|
-
inference_method="score",
|
1096
|
-
)
|
1097
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
1098
|
+
self._deps = self._get_dependencies()
|
1097
1099
|
selected_cols = self._get_active_columns()
|
1098
1100
|
if len(selected_cols) > 0:
|
1099
1101
|
dataset = dataset.select(selected_cols)
|
1100
1102
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
1101
1103
|
transform_kwargs = dict(
|
1102
1104
|
session=dataset._session,
|
1103
|
-
dependencies=
|
1105
|
+
dependencies=self._deps,
|
1104
1106
|
score_sproc_imports=['xgboost'],
|
1105
1107
|
)
|
1106
1108
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1165,11 +1167,8 @@ class XGBRFRegressor(BaseTransformer):
|
|
1165
1167
|
|
1166
1168
|
if isinstance(dataset, DataFrame):
|
1167
1169
|
|
1168
|
-
self.
|
1169
|
-
|
1170
|
-
inference_method=inference_method,
|
1171
|
-
|
1172
|
-
)
|
1170
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1171
|
+
self._deps = self._get_dependencies()
|
1173
1172
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1174
1173
|
transform_kwargs = dict(
|
1175
1174
|
session = dataset._session,
|
@@ -48,20 +48,29 @@ class ModelManager:
|
|
48
48
|
options: Optional[model_types.ModelSaveOption] = None,
|
49
49
|
statement_params: Optional[Dict[str, Any]] = None,
|
50
50
|
) -> model_version_impl.ModelVersion:
|
51
|
-
model_name_id = sql_identifier.
|
51
|
+
database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
|
52
52
|
|
53
53
|
if not version_name:
|
54
54
|
version_name = self._hrid_generator.generate()[1]
|
55
55
|
version_name_id = sql_identifier.SqlIdentifier(version_name)
|
56
56
|
|
57
57
|
if self._model_ops.validate_existence(
|
58
|
-
|
58
|
+
database_name=database_name_id,
|
59
|
+
schema_name=schema_name_id,
|
60
|
+
model_name=model_name_id,
|
61
|
+
statement_params=statement_params,
|
59
62
|
) and self._model_ops.validate_existence(
|
60
|
-
|
63
|
+
database_name=database_name_id,
|
64
|
+
schema_name=schema_name_id,
|
65
|
+
model_name=model_name_id,
|
66
|
+
version_name=version_name_id,
|
67
|
+
statement_params=statement_params,
|
61
68
|
):
|
62
69
|
raise ValueError(f"Model {model_name} version {version_name} already existed.")
|
63
70
|
|
64
71
|
stage_path = self._model_ops.prepare_model_stage_path(
|
72
|
+
database_name=database_name_id,
|
73
|
+
schema_name=schema_name_id,
|
65
74
|
statement_params=statement_params,
|
66
75
|
)
|
67
76
|
|
@@ -85,13 +94,19 @@ class ModelManager:
|
|
85
94
|
|
86
95
|
self._model_ops.create_from_stage(
|
87
96
|
composed_model=mc,
|
97
|
+
database_name=database_name_id,
|
98
|
+
schema_name=schema_name_id,
|
88
99
|
model_name=model_name_id,
|
89
100
|
version_name=version_name_id,
|
90
101
|
statement_params=statement_params,
|
91
102
|
)
|
92
103
|
|
93
104
|
mv = model_version_impl.ModelVersion._ref(
|
94
|
-
|
105
|
+
model_ops.ModelOperator(
|
106
|
+
self._model_ops._session,
|
107
|
+
database_name=database_name_id or self._database_name,
|
108
|
+
schema_name=schema_name_id or self._schema_name,
|
109
|
+
),
|
95
110
|
model_name=model_name_id,
|
96
111
|
version_name=version_name_id,
|
97
112
|
)
|
@@ -102,6 +117,8 @@ class ModelManager:
|
|
102
117
|
if metrics:
|
103
118
|
self._model_ops._metadata_ops.save(
|
104
119
|
metadata_ops.ModelVersionMetadataSchema(metrics=metrics),
|
120
|
+
database_name=database_name_id,
|
121
|
+
schema_name=schema_name_id,
|
105
122
|
model_name=model_name_id,
|
106
123
|
version_name=version_name_id,
|
107
124
|
statement_params=statement_params,
|
@@ -115,13 +132,19 @@ class ModelManager:
|
|
115
132
|
*,
|
116
133
|
statement_params: Optional[Dict[str, Any]] = None,
|
117
134
|
) -> model_impl.Model:
|
118
|
-
model_name_id = sql_identifier.
|
135
|
+
database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
|
119
136
|
if self._model_ops.validate_existence(
|
137
|
+
database_name=database_name_id,
|
138
|
+
schema_name=schema_name_id,
|
120
139
|
model_name=model_name_id,
|
121
140
|
statement_params=statement_params,
|
122
141
|
):
|
123
142
|
return model_impl.Model._ref(
|
124
|
-
|
143
|
+
model_ops.ModelOperator(
|
144
|
+
self._model_ops._session,
|
145
|
+
database_name=database_name_id or self._database_name,
|
146
|
+
schema_name=schema_name_id or self._schema_name,
|
147
|
+
),
|
125
148
|
model_name=model_name_id,
|
126
149
|
)
|
127
150
|
else:
|
@@ -133,6 +156,8 @@ class ModelManager:
|
|
133
156
|
statement_params: Optional[Dict[str, Any]] = None,
|
134
157
|
) -> List[model_impl.Model]:
|
135
158
|
model_names = self._model_ops.list_models_or_versions(
|
159
|
+
database_name=None,
|
160
|
+
schema_name=None,
|
136
161
|
statement_params=statement_params,
|
137
162
|
)
|
138
163
|
return [
|
@@ -149,6 +174,8 @@ class ModelManager:
|
|
149
174
|
statement_params: Optional[Dict[str, Any]] = None,
|
150
175
|
) -> pd.DataFrame:
|
151
176
|
rows = self._model_ops.show_models_or_versions(
|
177
|
+
database_name=None,
|
178
|
+
schema_name=None,
|
152
179
|
statement_params=statement_params,
|
153
180
|
)
|
154
181
|
return pd.DataFrame([row.as_dict() for row in rows])
|
@@ -159,9 +186,11 @@ class ModelManager:
|
|
159
186
|
*,
|
160
187
|
statement_params: Optional[Dict[str, Any]] = None,
|
161
188
|
) -> None:
|
162
|
-
model_name_id = sql_identifier.
|
189
|
+
database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
|
163
190
|
|
164
191
|
self._model_ops.delete_model_or_version(
|
192
|
+
database_name=database_name_id,
|
193
|
+
schema_name=schema_name_id,
|
165
194
|
model_name=model_name_id,
|
166
195
|
statement_params=statement_params,
|
167
196
|
)
|