snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -32,6 +32,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
|
|
32
32
|
BatchInferenceKwargsTypedDict,
|
33
33
|
ScoreKwargsTypedDict
|
34
34
|
)
|
35
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
36
|
+
from snowflake.ml.model.model_signature import (
|
37
|
+
BaseFeatureSpec,
|
38
|
+
DataType,
|
39
|
+
FeatureSpec,
|
40
|
+
ModelSignature,
|
41
|
+
_infer_signature,
|
42
|
+
_rename_signature_with_snowflake_identifiers,
|
43
|
+
)
|
35
44
|
|
36
45
|
from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
|
37
46
|
|
@@ -42,16 +51,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
|
|
42
51
|
validate_sklearn_args,
|
43
52
|
)
|
44
53
|
|
45
|
-
from snowflake.ml.model.model_signature import (
|
46
|
-
DataType,
|
47
|
-
FeatureSpec,
|
48
|
-
ModelSignature,
|
49
|
-
_infer_signature,
|
50
|
-
_rename_signature_with_snowflake_identifiers,
|
51
|
-
BaseFeatureSpec,
|
52
|
-
)
|
53
|
-
from snowflake.ml.model._signatures import utils as model_signature_utils
|
54
|
-
|
55
54
|
_PROJECT = "ModelDevelopment"
|
56
55
|
# Derive subproject from module name by removing "sklearn"
|
57
56
|
# and converting module name from underscore to CamelCase
|
@@ -60,12 +59,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "xgboost".replace("sklearn.", "")
|
|
60
59
|
|
61
60
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
61
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
62
|
class XGBRegressor(BaseTransformer):
|
70
63
|
r"""Implementation of the scikit-learn API for XGBoost regression
|
71
64
|
For more details on this class, see [xgboost.XGBRegressor]
|
@@ -421,12 +414,7 @@ class XGBRegressor(BaseTransformer):
|
|
421
414
|
)
|
422
415
|
return selected_cols
|
423
416
|
|
424
|
-
|
425
|
-
project=_PROJECT,
|
426
|
-
subproject=_SUBPROJECT,
|
427
|
-
custom_tags=dict([("autogen", True)]),
|
428
|
-
)
|
429
|
-
def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "XGBRegressor":
|
417
|
+
def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "XGBRegressor":
|
430
418
|
"""Fit gradient boosting model
|
431
419
|
For more details on this function, see [xgboost.XGBRegressor.fit]
|
432
420
|
(https://xgboost.readthedocs.io/en/stable/python/python_api.html#xgboost.XGBRegressor.fit)
|
@@ -453,12 +441,14 @@ class XGBRegressor(BaseTransformer):
|
|
453
441
|
|
454
442
|
self._snowpark_cols = dataset.select(self.input_cols).columns
|
455
443
|
|
456
|
-
|
444
|
+
# If we are already in a stored procedure, no need to kick off another one.
|
457
445
|
if SNOWML_SPROC_ENV in os.environ:
|
458
446
|
statement_params = telemetry.get_function_usage_statement_params(
|
459
447
|
project=_PROJECT,
|
460
448
|
subproject=_SUBPROJECT,
|
461
|
-
function_name=telemetry.get_statement_params_full_func_name(
|
449
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
450
|
+
inspect.currentframe(), XGBRegressor.__class__.__name__
|
451
|
+
),
|
462
452
|
api_calls=[Session.call],
|
463
453
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
464
454
|
)
|
@@ -479,27 +469,24 @@ class XGBRegressor(BaseTransformer):
|
|
479
469
|
)
|
480
470
|
self._sklearn_object = model_trainer.train()
|
481
471
|
self._is_fitted = True
|
482
|
-
self.
|
472
|
+
self._generate_model_signatures(dataset)
|
483
473
|
return self
|
484
474
|
|
485
475
|
def _batch_inference_validate_snowpark(
|
486
476
|
self,
|
487
477
|
dataset: DataFrame,
|
488
478
|
inference_method: str,
|
489
|
-
) ->
|
490
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
491
|
-
return the available package that exists in the snowflake anaconda channel
|
479
|
+
) -> None:
|
480
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
492
481
|
|
493
482
|
Args:
|
494
483
|
dataset: snowpark dataframe
|
495
484
|
inference_method: the inference method such as predict, score...
|
496
|
-
|
485
|
+
|
497
486
|
Raises:
|
498
487
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
499
488
|
SnowflakeMLException: If the session is None, raise error
|
500
489
|
|
501
|
-
Returns:
|
502
|
-
A list of available package that exists in the snowflake anaconda channel
|
503
490
|
"""
|
504
491
|
if not self._is_fitted:
|
505
492
|
raise exceptions.SnowflakeMLException(
|
@@ -517,9 +504,7 @@ class XGBRegressor(BaseTransformer):
|
|
517
504
|
"Session must not specified for snowpark dataset."
|
518
505
|
),
|
519
506
|
)
|
520
|
-
|
521
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
522
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
507
|
+
|
523
508
|
|
524
509
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
525
510
|
@telemetry.send_api_usage_telemetry(
|
@@ -555,7 +540,9 @@ class XGBRegressor(BaseTransformer):
|
|
555
540
|
# when it is classifier, infer the datatype from label columns
|
556
541
|
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
557
542
|
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
558
|
-
label_cols_signatures = [
|
543
|
+
label_cols_signatures = [
|
544
|
+
row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
|
545
|
+
]
|
559
546
|
if len(label_cols_signatures) == 0:
|
560
547
|
error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
|
561
548
|
raise exceptions.SnowflakeMLException(
|
@@ -563,25 +550,23 @@ class XGBRegressor(BaseTransformer):
|
|
563
550
|
original_exception=ValueError(error_str),
|
564
551
|
)
|
565
552
|
|
566
|
-
expected_type_inferred = convert_sp_to_sf_type(
|
567
|
-
label_cols_signatures[0].as_snowpark_type()
|
568
|
-
)
|
553
|
+
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
569
554
|
|
570
|
-
self.
|
571
|
-
|
555
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
556
|
+
self._deps = self._get_dependencies()
|
557
|
+
assert isinstance(
|
558
|
+
dataset._session, Session
|
559
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
572
560
|
|
573
561
|
transform_kwargs = dict(
|
574
|
-
session
|
575
|
-
dependencies
|
576
|
-
drop_input_cols
|
577
|
-
expected_output_cols_type
|
562
|
+
session=dataset._session,
|
563
|
+
dependencies=self._deps,
|
564
|
+
drop_input_cols=self._drop_input_cols,
|
565
|
+
expected_output_cols_type=expected_type_inferred,
|
578
566
|
)
|
579
567
|
|
580
568
|
elif isinstance(dataset, pd.DataFrame):
|
581
|
-
transform_kwargs = dict(
|
582
|
-
snowpark_input_cols = self._snowpark_cols,
|
583
|
-
drop_input_cols = self._drop_input_cols
|
584
|
-
)
|
569
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
585
570
|
|
586
571
|
transform_handlers = ModelTransformerBuilder.build(
|
587
572
|
dataset=dataset,
|
@@ -621,7 +606,7 @@ class XGBRegressor(BaseTransformer):
|
|
621
606
|
Transformed dataset.
|
622
607
|
"""
|
623
608
|
super()._check_dataset_type(dataset)
|
624
|
-
inference_method="transform"
|
609
|
+
inference_method = "transform"
|
625
610
|
|
626
611
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
627
612
|
# are specific to the type of dataset used.
|
@@ -651,24 +636,19 @@ class XGBRegressor(BaseTransformer):
|
|
651
636
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
652
637
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
653
638
|
|
654
|
-
self.
|
655
|
-
|
656
|
-
inference_method=inference_method,
|
657
|
-
)
|
639
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
640
|
+
self._deps = self._get_dependencies()
|
658
641
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
659
642
|
|
660
643
|
transform_kwargs = dict(
|
661
|
-
session
|
662
|
-
dependencies
|
663
|
-
drop_input_cols
|
664
|
-
expected_output_cols_type
|
644
|
+
session=dataset._session,
|
645
|
+
dependencies=self._deps,
|
646
|
+
drop_input_cols=self._drop_input_cols,
|
647
|
+
expected_output_cols_type=expected_dtype,
|
665
648
|
)
|
666
649
|
|
667
650
|
elif isinstance(dataset, pd.DataFrame):
|
668
|
-
transform_kwargs = dict(
|
669
|
-
snowpark_input_cols = self._snowpark_cols,
|
670
|
-
drop_input_cols = self._drop_input_cols
|
671
|
-
)
|
651
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
672
652
|
|
673
653
|
transform_handlers = ModelTransformerBuilder.build(
|
674
654
|
dataset=dataset,
|
@@ -687,7 +667,11 @@ class XGBRegressor(BaseTransformer):
|
|
687
667
|
return output_df
|
688
668
|
|
689
669
|
@available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
|
690
|
-
def fit_predict(
|
670
|
+
def fit_predict(
|
671
|
+
self,
|
672
|
+
dataset: Union[DataFrame, pd.DataFrame],
|
673
|
+
output_cols_prefix: str = "fit_predict_",
|
674
|
+
) -> Union[DataFrame, pd.DataFrame]:
|
691
675
|
""" Method not supported for this class.
|
692
676
|
|
693
677
|
|
@@ -712,22 +696,104 @@ class XGBRegressor(BaseTransformer):
|
|
712
696
|
)
|
713
697
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
714
698
|
drop_input_cols=self._drop_input_cols,
|
715
|
-
expected_output_cols_list=
|
699
|
+
expected_output_cols_list=(
|
700
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
701
|
+
),
|
716
702
|
)
|
717
703
|
self._sklearn_object = fitted_estimator
|
718
704
|
self._is_fitted = True
|
719
705
|
return output_result
|
720
706
|
|
707
|
+
|
708
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
709
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
710
|
+
""" Method not supported for this class.
|
711
|
+
|
721
712
|
|
722
|
-
|
723
|
-
|
724
|
-
|
713
|
+
Raises:
|
714
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
715
|
+
|
716
|
+
Args:
|
717
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
718
|
+
Snowpark or Pandas DataFrame.
|
719
|
+
output_cols_prefix: Prefix for the response columns
|
725
720
|
Returns:
|
726
721
|
Transformed dataset.
|
727
722
|
"""
|
728
|
-
self.
|
729
|
-
|
730
|
-
|
723
|
+
self._infer_input_output_cols(dataset)
|
724
|
+
super()._check_dataset_type(dataset)
|
725
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
726
|
+
estimator=self._sklearn_object,
|
727
|
+
dataset=dataset,
|
728
|
+
input_cols=self.input_cols,
|
729
|
+
label_cols=self.label_cols,
|
730
|
+
sample_weight_col=self.sample_weight_col,
|
731
|
+
autogenerated=self._autogenerated,
|
732
|
+
subproject=_SUBPROJECT,
|
733
|
+
)
|
734
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
735
|
+
drop_input_cols=self._drop_input_cols,
|
736
|
+
expected_output_cols_list=self.output_cols,
|
737
|
+
)
|
738
|
+
self._sklearn_object = fitted_estimator
|
739
|
+
self._is_fitted = True
|
740
|
+
return output_result
|
741
|
+
|
742
|
+
|
743
|
+
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
744
|
+
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
745
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
746
|
+
"""
|
747
|
+
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
748
|
+
# The following condition is introduced for kneighbors methods, and not used in other methods
|
749
|
+
if output_cols:
|
750
|
+
output_cols = [
|
751
|
+
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
752
|
+
for c in output_cols
|
753
|
+
]
|
754
|
+
elif getattr(self._sklearn_object, "classes_", None) is None:
|
755
|
+
output_cols = [output_cols_prefix]
|
756
|
+
elif self._sklearn_object is not None:
|
757
|
+
classes = self._sklearn_object.classes_
|
758
|
+
if isinstance(classes, numpy.ndarray):
|
759
|
+
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
760
|
+
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
761
|
+
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
762
|
+
output_cols = []
|
763
|
+
for i, cl in enumerate(classes):
|
764
|
+
# For binary classification, there is only one output column for each class
|
765
|
+
# ndarray as the two classes are complementary.
|
766
|
+
if len(cl) == 2:
|
767
|
+
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
768
|
+
else:
|
769
|
+
output_cols.extend([
|
770
|
+
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
771
|
+
])
|
772
|
+
else:
|
773
|
+
output_cols = []
|
774
|
+
|
775
|
+
# Make sure column names are valid snowflake identifiers.
|
776
|
+
assert output_cols is not None # Make MyPy happy
|
777
|
+
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
778
|
+
|
779
|
+
return rv
|
780
|
+
|
781
|
+
def _align_expected_output_names(
|
782
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
783
|
+
) -> List[str]:
|
784
|
+
# in case the inferred output column names dimension is different
|
785
|
+
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
786
|
+
output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
|
787
|
+
output_df_columns = list(output_df_pd.columns)
|
788
|
+
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
789
|
+
if self.sample_weight_col:
|
790
|
+
output_df_columns_set -= set(self.sample_weight_col)
|
791
|
+
# if the dimension of inferred output column names is correct; use it
|
792
|
+
if len(expected_output_cols_list) == len(output_df_columns_set):
|
793
|
+
return expected_output_cols_list
|
794
|
+
# otherwise, use the sklearn estimator's output
|
795
|
+
else:
|
796
|
+
return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
731
797
|
|
732
798
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
733
799
|
@telemetry.send_api_usage_telemetry(
|
@@ -759,24 +825,26 @@ class XGBRegressor(BaseTransformer):
|
|
759
825
|
# are specific to the type of dataset used.
|
760
826
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
761
827
|
|
828
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
829
|
+
|
762
830
|
if isinstance(dataset, DataFrame):
|
763
|
-
self.
|
764
|
-
|
765
|
-
|
766
|
-
|
767
|
-
|
831
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
832
|
+
self._deps = self._get_dependencies()
|
833
|
+
assert isinstance(
|
834
|
+
dataset._session, Session
|
835
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
768
836
|
transform_kwargs = dict(
|
769
837
|
session=dataset._session,
|
770
838
|
dependencies=self._deps,
|
771
|
-
drop_input_cols
|
839
|
+
drop_input_cols=self._drop_input_cols,
|
772
840
|
expected_output_cols_type="float",
|
773
841
|
)
|
842
|
+
expected_output_cols = self._align_expected_output_names(
|
843
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
844
|
+
)
|
774
845
|
|
775
846
|
elif isinstance(dataset, pd.DataFrame):
|
776
|
-
transform_kwargs = dict(
|
777
|
-
snowpark_input_cols = self._snowpark_cols,
|
778
|
-
drop_input_cols = self._drop_input_cols
|
779
|
-
)
|
847
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
780
848
|
|
781
849
|
transform_handlers = ModelTransformerBuilder.build(
|
782
850
|
dataset=dataset,
|
@@ -788,7 +856,7 @@ class XGBRegressor(BaseTransformer):
|
|
788
856
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
789
857
|
inference_method=inference_method,
|
790
858
|
input_cols=self.input_cols,
|
791
|
-
expected_output_cols=
|
859
|
+
expected_output_cols=expected_output_cols,
|
792
860
|
**transform_kwargs
|
793
861
|
)
|
794
862
|
return output_df
|
@@ -818,29 +886,30 @@ class XGBRegressor(BaseTransformer):
|
|
818
886
|
Output dataset with log probability of the sample for each class in the model.
|
819
887
|
"""
|
820
888
|
super()._check_dataset_type(dataset)
|
821
|
-
inference_method="predict_log_proba"
|
889
|
+
inference_method = "predict_log_proba"
|
890
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
822
891
|
|
823
892
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
824
893
|
# are specific to the type of dataset used.
|
825
894
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
826
895
|
|
827
896
|
if isinstance(dataset, DataFrame):
|
828
|
-
self.
|
829
|
-
|
830
|
-
|
831
|
-
|
832
|
-
|
897
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
898
|
+
self._deps = self._get_dependencies()
|
899
|
+
assert isinstance(
|
900
|
+
dataset._session, Session
|
901
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
833
902
|
transform_kwargs = dict(
|
834
903
|
session=dataset._session,
|
835
904
|
dependencies=self._deps,
|
836
|
-
drop_input_cols
|
905
|
+
drop_input_cols=self._drop_input_cols,
|
837
906
|
expected_output_cols_type="float",
|
838
907
|
)
|
908
|
+
expected_output_cols = self._align_expected_output_names(
|
909
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
910
|
+
)
|
839
911
|
elif isinstance(dataset, pd.DataFrame):
|
840
|
-
transform_kwargs = dict(
|
841
|
-
snowpark_input_cols = self._snowpark_cols,
|
842
|
-
drop_input_cols = self._drop_input_cols
|
843
|
-
)
|
912
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
844
913
|
|
845
914
|
transform_handlers = ModelTransformerBuilder.build(
|
846
915
|
dataset=dataset,
|
@@ -853,7 +922,7 @@ class XGBRegressor(BaseTransformer):
|
|
853
922
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
854
923
|
inference_method=inference_method,
|
855
924
|
input_cols=self.input_cols,
|
856
|
-
expected_output_cols=
|
925
|
+
expected_output_cols=expected_output_cols,
|
857
926
|
**transform_kwargs
|
858
927
|
)
|
859
928
|
return output_df
|
@@ -879,30 +948,32 @@ class XGBRegressor(BaseTransformer):
|
|
879
948
|
Output dataset with results of the decision function for the samples in input dataset.
|
880
949
|
"""
|
881
950
|
super()._check_dataset_type(dataset)
|
882
|
-
inference_method="decision_function"
|
951
|
+
inference_method = "decision_function"
|
883
952
|
|
884
953
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
885
954
|
# are specific to the type of dataset used.
|
886
955
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
887
956
|
|
957
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
958
|
+
|
888
959
|
if isinstance(dataset, DataFrame):
|
889
|
-
self.
|
890
|
-
|
891
|
-
|
892
|
-
|
893
|
-
|
960
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
961
|
+
self._deps = self._get_dependencies()
|
962
|
+
assert isinstance(
|
963
|
+
dataset._session, Session
|
964
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
894
965
|
transform_kwargs = dict(
|
895
966
|
session=dataset._session,
|
896
967
|
dependencies=self._deps,
|
897
|
-
drop_input_cols
|
968
|
+
drop_input_cols=self._drop_input_cols,
|
898
969
|
expected_output_cols_type="float",
|
899
970
|
)
|
971
|
+
expected_output_cols = self._align_expected_output_names(
|
972
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
973
|
+
)
|
900
974
|
|
901
975
|
elif isinstance(dataset, pd.DataFrame):
|
902
|
-
transform_kwargs = dict(
|
903
|
-
snowpark_input_cols = self._snowpark_cols,
|
904
|
-
drop_input_cols = self._drop_input_cols
|
905
|
-
)
|
976
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
906
977
|
|
907
978
|
transform_handlers = ModelTransformerBuilder.build(
|
908
979
|
dataset=dataset,
|
@@ -915,7 +986,7 @@ class XGBRegressor(BaseTransformer):
|
|
915
986
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
916
987
|
inference_method=inference_method,
|
917
988
|
input_cols=self.input_cols,
|
918
|
-
expected_output_cols=
|
989
|
+
expected_output_cols=expected_output_cols,
|
919
990
|
**transform_kwargs
|
920
991
|
)
|
921
992
|
return output_df
|
@@ -944,17 +1015,17 @@ class XGBRegressor(BaseTransformer):
|
|
944
1015
|
Output dataset with probability of the sample for each class in the model.
|
945
1016
|
"""
|
946
1017
|
super()._check_dataset_type(dataset)
|
947
|
-
inference_method="score_samples"
|
1018
|
+
inference_method = "score_samples"
|
948
1019
|
|
949
1020
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
950
1021
|
# are specific to the type of dataset used.
|
951
1022
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
952
1023
|
|
1024
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
1025
|
+
|
953
1026
|
if isinstance(dataset, DataFrame):
|
954
|
-
self.
|
955
|
-
|
956
|
-
inference_method=inference_method,
|
957
|
-
)
|
1027
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1028
|
+
self._deps = self._get_dependencies()
|
958
1029
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
959
1030
|
transform_kwargs = dict(
|
960
1031
|
session=dataset._session,
|
@@ -962,6 +1033,9 @@ class XGBRegressor(BaseTransformer):
|
|
962
1033
|
drop_input_cols = self._drop_input_cols,
|
963
1034
|
expected_output_cols_type="float",
|
964
1035
|
)
|
1036
|
+
expected_output_cols = self._align_expected_output_names(
|
1037
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
1038
|
+
)
|
965
1039
|
|
966
1040
|
elif isinstance(dataset, pd.DataFrame):
|
967
1041
|
transform_kwargs = dict(
|
@@ -980,7 +1054,7 @@ class XGBRegressor(BaseTransformer):
|
|
980
1054
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
981
1055
|
inference_method=inference_method,
|
982
1056
|
input_cols=self.input_cols,
|
983
|
-
expected_output_cols=
|
1057
|
+
expected_output_cols=expected_output_cols,
|
984
1058
|
**transform_kwargs
|
985
1059
|
)
|
986
1060
|
return output_df
|
@@ -1015,17 +1089,15 @@ class XGBRegressor(BaseTransformer):
|
|
1015
1089
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
1016
1090
|
|
1017
1091
|
if isinstance(dataset, DataFrame):
|
1018
|
-
self.
|
1019
|
-
|
1020
|
-
inference_method="score",
|
1021
|
-
)
|
1092
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
1093
|
+
self._deps = self._get_dependencies()
|
1022
1094
|
selected_cols = self._get_active_columns()
|
1023
1095
|
if len(selected_cols) > 0:
|
1024
1096
|
dataset = dataset.select(selected_cols)
|
1025
1097
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
1026
1098
|
transform_kwargs = dict(
|
1027
1099
|
session=dataset._session,
|
1028
|
-
dependencies=
|
1100
|
+
dependencies=self._deps,
|
1029
1101
|
score_sproc_imports=['xgboost'],
|
1030
1102
|
)
|
1031
1103
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1090,11 +1162,8 @@ class XGBRegressor(BaseTransformer):
|
|
1090
1162
|
|
1091
1163
|
if isinstance(dataset, DataFrame):
|
1092
1164
|
|
1093
|
-
self.
|
1094
|
-
|
1095
|
-
inference_method=inference_method,
|
1096
|
-
|
1097
|
-
)
|
1165
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1166
|
+
self._deps = self._get_dependencies()
|
1098
1167
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1099
1168
|
transform_kwargs = dict(
|
1100
1169
|
session = dataset._session,
|
@@ -1127,50 +1196,84 @@ class XGBRegressor(BaseTransformer):
|
|
1127
1196
|
)
|
1128
1197
|
return output_df
|
1129
1198
|
|
1199
|
+
|
1200
|
+
|
1201
|
+
def to_xgboost(self) -> Any:
|
1202
|
+
"""Get xgboost.XGBRegressor object.
|
1203
|
+
"""
|
1204
|
+
if self._sklearn_object is None:
|
1205
|
+
self._sklearn_object = self._create_sklearn_object()
|
1206
|
+
return self._sklearn_object
|
1207
|
+
|
1208
|
+
def to_sklearn(self) -> Any:
|
1209
|
+
raise exceptions.SnowflakeMLException(
|
1210
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1211
|
+
original_exception=AttributeError(
|
1212
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1213
|
+
"to_sklearn()",
|
1214
|
+
"to_xgboost()"
|
1215
|
+
)
|
1216
|
+
),
|
1217
|
+
)
|
1218
|
+
|
1219
|
+
def to_lightgbm(self) -> Any:
|
1220
|
+
raise exceptions.SnowflakeMLException(
|
1221
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1222
|
+
original_exception=AttributeError(
|
1223
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1224
|
+
"to_lightgbm()",
|
1225
|
+
"to_xgboost()"
|
1226
|
+
)
|
1227
|
+
),
|
1228
|
+
)
|
1229
|
+
|
1230
|
+
def _get_dependencies(self) -> List[str]:
|
1231
|
+
return self._deps
|
1232
|
+
|
1130
1233
|
|
1131
|
-
def
|
1234
|
+
def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
1132
1235
|
self._model_signature_dict = dict()
|
1133
1236
|
|
1134
1237
|
PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
|
1135
1238
|
|
1136
|
-
inputs = list(_infer_signature(dataset[self.input_cols], "input"))
|
1239
|
+
inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
|
1137
1240
|
outputs: List[BaseFeatureSpec] = []
|
1138
1241
|
if hasattr(self, "predict"):
|
1139
1242
|
# keep mypy happy
|
1140
|
-
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1243
|
+
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1141
1244
|
# For classifier, the type of predict is the same as the type of label
|
1142
|
-
if self._sklearn_object._estimator_type ==
|
1143
|
-
|
1245
|
+
if self._sklearn_object._estimator_type == "classifier":
|
1246
|
+
# label columns is the desired type for output
|
1144
1247
|
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
1145
1248
|
# rename the output columns
|
1146
1249
|
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
1147
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1148
|
-
|
1149
|
-
|
1250
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1251
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1252
|
+
)
|
1150
1253
|
# For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
|
1151
1254
|
# For outlier models, returns -1 for outliers and 1 for inliers.
|
1152
|
-
# Clusterer returns int64 cluster labels.
|
1255
|
+
# Clusterer returns int64 cluster labels.
|
1153
1256
|
elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
|
1154
1257
|
outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
|
1155
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1156
|
-
|
1157
|
-
|
1158
|
-
|
1258
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1259
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1260
|
+
)
|
1261
|
+
|
1159
1262
|
# For regressor, the type of predict is float64
|
1160
|
-
elif self._sklearn_object._estimator_type ==
|
1263
|
+
elif self._sklearn_object._estimator_type == "regressor":
|
1161
1264
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
1162
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1163
|
-
|
1164
|
-
|
1165
|
-
|
1265
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1266
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1267
|
+
)
|
1268
|
+
|
1166
1269
|
for prob_func in PROB_FUNCTIONS:
|
1167
1270
|
if hasattr(self, prob_func):
|
1168
1271
|
output_cols_prefix: str = f"{prob_func}_"
|
1169
1272
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
1170
1273
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
1171
|
-
self._model_signature_dict[prob_func] = ModelSignature(
|
1172
|
-
|
1173
|
-
|
1274
|
+
self._model_signature_dict[prob_func] = ModelSignature(
|
1275
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1276
|
+
)
|
1174
1277
|
|
1175
1278
|
# Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
|
1176
1279
|
items = list(self._model_signature_dict.items())
|
@@ -1183,10 +1286,10 @@ class XGBRegressor(BaseTransformer):
|
|
1183
1286
|
"""Returns model signature of current class.
|
1184
1287
|
|
1185
1288
|
Raises:
|
1186
|
-
|
1289
|
+
SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
|
1187
1290
|
|
1188
1291
|
Returns:
|
1189
|
-
Dict
|
1292
|
+
Dict with each method and its input output signature
|
1190
1293
|
"""
|
1191
1294
|
if self._model_signature_dict is None:
|
1192
1295
|
raise exceptions.SnowflakeMLException(
|
@@ -1194,35 +1297,3 @@ class XGBRegressor(BaseTransformer):
|
|
1194
1297
|
original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
|
1195
1298
|
)
|
1196
1299
|
return self._model_signature_dict
|
1197
|
-
|
1198
|
-
def to_xgboost(self) -> Any:
|
1199
|
-
"""Get xgboost.XGBRegressor object.
|
1200
|
-
"""
|
1201
|
-
if self._sklearn_object is None:
|
1202
|
-
self._sklearn_object = self._create_sklearn_object()
|
1203
|
-
return self._sklearn_object
|
1204
|
-
|
1205
|
-
def to_sklearn(self) -> Any:
|
1206
|
-
raise exceptions.SnowflakeMLException(
|
1207
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1208
|
-
original_exception=AttributeError(
|
1209
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1210
|
-
"to_sklearn()",
|
1211
|
-
"to_xgboost()"
|
1212
|
-
)
|
1213
|
-
),
|
1214
|
-
)
|
1215
|
-
|
1216
|
-
def to_lightgbm(self) -> Any:
|
1217
|
-
raise exceptions.SnowflakeMLException(
|
1218
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1219
|
-
original_exception=AttributeError(
|
1220
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1221
|
-
"to_lightgbm()",
|
1222
|
-
"to_xgboost()"
|
1223
|
-
)
|
1224
|
-
),
|
1225
|
-
)
|
1226
|
-
|
1227
|
-
def _get_dependencies(self) -> List[str]:
|
1228
|
-
return self._deps
|