snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class GradientBoostingRegressor(BaseTransformer):
|
70
64
|
r"""Gradient Boosting for regression
|
71
65
|
For more details on this class, see [sklearn.ensemble.GradientBoostingRegressor]
|
@@ -461,20 +455,17 @@ class GradientBoostingRegressor(BaseTransformer):
|
|
461
455
|
self,
|
462
456
|
dataset: DataFrame,
|
463
457
|
inference_method: str,
|
464
|
-
) ->
|
465
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
466
|
-
return the available package that exists in the snowflake anaconda channel
|
458
|
+
) -> None:
|
459
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
467
460
|
|
468
461
|
Args:
|
469
462
|
dataset: snowpark dataframe
|
470
463
|
inference_method: the inference method such as predict, score...
|
471
|
-
|
464
|
+
|
472
465
|
Raises:
|
473
466
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
474
467
|
SnowflakeMLException: If the session is None, raise error
|
475
468
|
|
476
|
-
Returns:
|
477
|
-
A list of available package that exists in the snowflake anaconda channel
|
478
469
|
"""
|
479
470
|
if not self._is_fitted:
|
480
471
|
raise exceptions.SnowflakeMLException(
|
@@ -492,9 +483,7 @@ class GradientBoostingRegressor(BaseTransformer):
|
|
492
483
|
"Session must not specified for snowpark dataset."
|
493
484
|
),
|
494
485
|
)
|
495
|
-
|
496
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
497
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
486
|
+
|
498
487
|
|
499
488
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
500
489
|
@telemetry.send_api_usage_telemetry(
|
@@ -542,7 +531,8 @@ class GradientBoostingRegressor(BaseTransformer):
|
|
542
531
|
|
543
532
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
544
533
|
|
545
|
-
self.
|
534
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
535
|
+
self._deps = self._get_dependencies()
|
546
536
|
assert isinstance(
|
547
537
|
dataset._session, Session
|
548
538
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -625,10 +615,8 @@ class GradientBoostingRegressor(BaseTransformer):
|
|
625
615
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
626
616
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
627
617
|
|
628
|
-
self.
|
629
|
-
|
630
|
-
inference_method=inference_method,
|
631
|
-
)
|
618
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
619
|
+
self._deps = self._get_dependencies()
|
632
620
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
633
621
|
|
634
622
|
transform_kwargs = dict(
|
@@ -695,16 +683,40 @@ class GradientBoostingRegressor(BaseTransformer):
|
|
695
683
|
self._is_fitted = True
|
696
684
|
return output_result
|
697
685
|
|
686
|
+
|
687
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
688
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
689
|
+
""" Method not supported for this class.
|
698
690
|
|
699
|
-
|
700
|
-
|
701
|
-
|
691
|
+
|
692
|
+
Raises:
|
693
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
694
|
+
|
695
|
+
Args:
|
696
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
697
|
+
Snowpark or Pandas DataFrame.
|
698
|
+
output_cols_prefix: Prefix for the response columns
|
702
699
|
Returns:
|
703
700
|
Transformed dataset.
|
704
701
|
"""
|
705
|
-
self.
|
706
|
-
|
707
|
-
|
702
|
+
self._infer_input_output_cols(dataset)
|
703
|
+
super()._check_dataset_type(dataset)
|
704
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
705
|
+
estimator=self._sklearn_object,
|
706
|
+
dataset=dataset,
|
707
|
+
input_cols=self.input_cols,
|
708
|
+
label_cols=self.label_cols,
|
709
|
+
sample_weight_col=self.sample_weight_col,
|
710
|
+
autogenerated=self._autogenerated,
|
711
|
+
subproject=_SUBPROJECT,
|
712
|
+
)
|
713
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
714
|
+
drop_input_cols=self._drop_input_cols,
|
715
|
+
expected_output_cols_list=self.output_cols,
|
716
|
+
)
|
717
|
+
self._sklearn_object = fitted_estimator
|
718
|
+
self._is_fitted = True
|
719
|
+
return output_result
|
708
720
|
|
709
721
|
|
710
722
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -795,10 +807,8 @@ class GradientBoostingRegressor(BaseTransformer):
|
|
795
807
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
796
808
|
|
797
809
|
if isinstance(dataset, DataFrame):
|
798
|
-
self.
|
799
|
-
|
800
|
-
inference_method=inference_method,
|
801
|
-
)
|
810
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
811
|
+
self._deps = self._get_dependencies()
|
802
812
|
assert isinstance(
|
803
813
|
dataset._session, Session
|
804
814
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -863,10 +873,8 @@ class GradientBoostingRegressor(BaseTransformer):
|
|
863
873
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
864
874
|
|
865
875
|
if isinstance(dataset, DataFrame):
|
866
|
-
self.
|
867
|
-
|
868
|
-
inference_method=inference_method,
|
869
|
-
)
|
876
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
877
|
+
self._deps = self._get_dependencies()
|
870
878
|
assert isinstance(
|
871
879
|
dataset._session, Session
|
872
880
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -928,10 +936,8 @@ class GradientBoostingRegressor(BaseTransformer):
|
|
928
936
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
929
937
|
|
930
938
|
if isinstance(dataset, DataFrame):
|
931
|
-
self.
|
932
|
-
|
933
|
-
inference_method=inference_method,
|
934
|
-
)
|
939
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
940
|
+
self._deps = self._get_dependencies()
|
935
941
|
assert isinstance(
|
936
942
|
dataset._session, Session
|
937
943
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -997,10 +1003,8 @@ class GradientBoostingRegressor(BaseTransformer):
|
|
997
1003
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
998
1004
|
|
999
1005
|
if isinstance(dataset, DataFrame):
|
1000
|
-
self.
|
1001
|
-
|
1002
|
-
inference_method=inference_method,
|
1003
|
-
)
|
1006
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1007
|
+
self._deps = self._get_dependencies()
|
1004
1008
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1005
1009
|
transform_kwargs = dict(
|
1006
1010
|
session=dataset._session,
|
@@ -1064,17 +1068,15 @@ class GradientBoostingRegressor(BaseTransformer):
|
|
1064
1068
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
1065
1069
|
|
1066
1070
|
if isinstance(dataset, DataFrame):
|
1067
|
-
self.
|
1068
|
-
|
1069
|
-
inference_method="score",
|
1070
|
-
)
|
1071
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
1072
|
+
self._deps = self._get_dependencies()
|
1071
1073
|
selected_cols = self._get_active_columns()
|
1072
1074
|
if len(selected_cols) > 0:
|
1073
1075
|
dataset = dataset.select(selected_cols)
|
1074
1076
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
1075
1077
|
transform_kwargs = dict(
|
1076
1078
|
session=dataset._session,
|
1077
|
-
dependencies=
|
1079
|
+
dependencies=self._deps,
|
1078
1080
|
score_sproc_imports=['sklearn'],
|
1079
1081
|
)
|
1080
1082
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1139,11 +1141,8 @@ class GradientBoostingRegressor(BaseTransformer):
|
|
1139
1141
|
|
1140
1142
|
if isinstance(dataset, DataFrame):
|
1141
1143
|
|
1142
|
-
self.
|
1143
|
-
|
1144
|
-
inference_method=inference_method,
|
1145
|
-
|
1146
|
-
)
|
1144
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1145
|
+
self._deps = self._get_dependencies()
|
1147
1146
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1148
1147
|
transform_kwargs = dict(
|
1149
1148
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class HistGradientBoostingClassifier(BaseTransformer):
|
70
64
|
r"""Histogram-based Gradient Boosting Classification Tree
|
71
65
|
For more details on this class, see [sklearn.ensemble.HistGradientBoostingClassifier]
|
@@ -433,20 +427,17 @@ class HistGradientBoostingClassifier(BaseTransformer):
|
|
433
427
|
self,
|
434
428
|
dataset: DataFrame,
|
435
429
|
inference_method: str,
|
436
|
-
) ->
|
437
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
438
|
-
return the available package that exists in the snowflake anaconda channel
|
430
|
+
) -> None:
|
431
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
439
432
|
|
440
433
|
Args:
|
441
434
|
dataset: snowpark dataframe
|
442
435
|
inference_method: the inference method such as predict, score...
|
443
|
-
|
436
|
+
|
444
437
|
Raises:
|
445
438
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
446
439
|
SnowflakeMLException: If the session is None, raise error
|
447
440
|
|
448
|
-
Returns:
|
449
|
-
A list of available package that exists in the snowflake anaconda channel
|
450
441
|
"""
|
451
442
|
if not self._is_fitted:
|
452
443
|
raise exceptions.SnowflakeMLException(
|
@@ -464,9 +455,7 @@ class HistGradientBoostingClassifier(BaseTransformer):
|
|
464
455
|
"Session must not specified for snowpark dataset."
|
465
456
|
),
|
466
457
|
)
|
467
|
-
|
468
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
469
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
458
|
+
|
470
459
|
|
471
460
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
472
461
|
@telemetry.send_api_usage_telemetry(
|
@@ -514,7 +503,8 @@ class HistGradientBoostingClassifier(BaseTransformer):
|
|
514
503
|
|
515
504
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
516
505
|
|
517
|
-
self.
|
506
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
507
|
+
self._deps = self._get_dependencies()
|
518
508
|
assert isinstance(
|
519
509
|
dataset._session, Session
|
520
510
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -597,10 +587,8 @@ class HistGradientBoostingClassifier(BaseTransformer):
|
|
597
587
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
598
588
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
599
589
|
|
600
|
-
self.
|
601
|
-
|
602
|
-
inference_method=inference_method,
|
603
|
-
)
|
590
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
591
|
+
self._deps = self._get_dependencies()
|
604
592
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
605
593
|
|
606
594
|
transform_kwargs = dict(
|
@@ -667,16 +655,40 @@ class HistGradientBoostingClassifier(BaseTransformer):
|
|
667
655
|
self._is_fitted = True
|
668
656
|
return output_result
|
669
657
|
|
658
|
+
|
659
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
660
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
661
|
+
""" Method not supported for this class.
|
670
662
|
|
671
|
-
|
672
|
-
|
673
|
-
|
663
|
+
|
664
|
+
Raises:
|
665
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
666
|
+
|
667
|
+
Args:
|
668
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
669
|
+
Snowpark or Pandas DataFrame.
|
670
|
+
output_cols_prefix: Prefix for the response columns
|
674
671
|
Returns:
|
675
672
|
Transformed dataset.
|
676
673
|
"""
|
677
|
-
self.
|
678
|
-
|
679
|
-
|
674
|
+
self._infer_input_output_cols(dataset)
|
675
|
+
super()._check_dataset_type(dataset)
|
676
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
677
|
+
estimator=self._sklearn_object,
|
678
|
+
dataset=dataset,
|
679
|
+
input_cols=self.input_cols,
|
680
|
+
label_cols=self.label_cols,
|
681
|
+
sample_weight_col=self.sample_weight_col,
|
682
|
+
autogenerated=self._autogenerated,
|
683
|
+
subproject=_SUBPROJECT,
|
684
|
+
)
|
685
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
686
|
+
drop_input_cols=self._drop_input_cols,
|
687
|
+
expected_output_cols_list=self.output_cols,
|
688
|
+
)
|
689
|
+
self._sklearn_object = fitted_estimator
|
690
|
+
self._is_fitted = True
|
691
|
+
return output_result
|
680
692
|
|
681
693
|
|
682
694
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -769,10 +781,8 @@ class HistGradientBoostingClassifier(BaseTransformer):
|
|
769
781
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
770
782
|
|
771
783
|
if isinstance(dataset, DataFrame):
|
772
|
-
self.
|
773
|
-
|
774
|
-
inference_method=inference_method,
|
775
|
-
)
|
784
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
785
|
+
self._deps = self._get_dependencies()
|
776
786
|
assert isinstance(
|
777
787
|
dataset._session, Session
|
778
788
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -839,10 +849,8 @@ class HistGradientBoostingClassifier(BaseTransformer):
|
|
839
849
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
840
850
|
|
841
851
|
if isinstance(dataset, DataFrame):
|
842
|
-
self.
|
843
|
-
|
844
|
-
inference_method=inference_method,
|
845
|
-
)
|
852
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
853
|
+
self._deps = self._get_dependencies()
|
846
854
|
assert isinstance(
|
847
855
|
dataset._session, Session
|
848
856
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -906,10 +914,8 @@ class HistGradientBoostingClassifier(BaseTransformer):
|
|
906
914
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
907
915
|
|
908
916
|
if isinstance(dataset, DataFrame):
|
909
|
-
self.
|
910
|
-
|
911
|
-
inference_method=inference_method,
|
912
|
-
)
|
917
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
918
|
+
self._deps = self._get_dependencies()
|
913
919
|
assert isinstance(
|
914
920
|
dataset._session, Session
|
915
921
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -975,10 +981,8 @@ class HistGradientBoostingClassifier(BaseTransformer):
|
|
975
981
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
976
982
|
|
977
983
|
if isinstance(dataset, DataFrame):
|
978
|
-
self.
|
979
|
-
|
980
|
-
inference_method=inference_method,
|
981
|
-
)
|
984
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
985
|
+
self._deps = self._get_dependencies()
|
982
986
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
983
987
|
transform_kwargs = dict(
|
984
988
|
session=dataset._session,
|
@@ -1042,17 +1046,15 @@ class HistGradientBoostingClassifier(BaseTransformer):
|
|
1042
1046
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
1043
1047
|
|
1044
1048
|
if isinstance(dataset, DataFrame):
|
1045
|
-
self.
|
1046
|
-
|
1047
|
-
inference_method="score",
|
1048
|
-
)
|
1049
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
1050
|
+
self._deps = self._get_dependencies()
|
1049
1051
|
selected_cols = self._get_active_columns()
|
1050
1052
|
if len(selected_cols) > 0:
|
1051
1053
|
dataset = dataset.select(selected_cols)
|
1052
1054
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
1053
1055
|
transform_kwargs = dict(
|
1054
1056
|
session=dataset._session,
|
1055
|
-
dependencies=
|
1057
|
+
dependencies=self._deps,
|
1056
1058
|
score_sproc_imports=['sklearn'],
|
1057
1059
|
)
|
1058
1060
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1117,11 +1119,8 @@ class HistGradientBoostingClassifier(BaseTransformer):
|
|
1117
1119
|
|
1118
1120
|
if isinstance(dataset, DataFrame):
|
1119
1121
|
|
1120
|
-
self.
|
1121
|
-
|
1122
|
-
inference_method=inference_method,
|
1123
|
-
|
1124
|
-
)
|
1122
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1123
|
+
self._deps = self._get_dependencies()
|
1125
1124
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1126
1125
|
transform_kwargs = dict(
|
1127
1126
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class HistGradientBoostingRegressor(BaseTransformer):
|
70
64
|
r"""Histogram-based Gradient Boosting Regression Tree
|
71
65
|
For more details on this class, see [sklearn.ensemble.HistGradientBoostingRegressor]
|
@@ -424,20 +418,17 @@ class HistGradientBoostingRegressor(BaseTransformer):
|
|
424
418
|
self,
|
425
419
|
dataset: DataFrame,
|
426
420
|
inference_method: str,
|
427
|
-
) ->
|
428
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
429
|
-
return the available package that exists in the snowflake anaconda channel
|
421
|
+
) -> None:
|
422
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
430
423
|
|
431
424
|
Args:
|
432
425
|
dataset: snowpark dataframe
|
433
426
|
inference_method: the inference method such as predict, score...
|
434
|
-
|
427
|
+
|
435
428
|
Raises:
|
436
429
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
437
430
|
SnowflakeMLException: If the session is None, raise error
|
438
431
|
|
439
|
-
Returns:
|
440
|
-
A list of available package that exists in the snowflake anaconda channel
|
441
432
|
"""
|
442
433
|
if not self._is_fitted:
|
443
434
|
raise exceptions.SnowflakeMLException(
|
@@ -455,9 +446,7 @@ class HistGradientBoostingRegressor(BaseTransformer):
|
|
455
446
|
"Session must not specified for snowpark dataset."
|
456
447
|
),
|
457
448
|
)
|
458
|
-
|
459
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
460
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
449
|
+
|
461
450
|
|
462
451
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
463
452
|
@telemetry.send_api_usage_telemetry(
|
@@ -505,7 +494,8 @@ class HistGradientBoostingRegressor(BaseTransformer):
|
|
505
494
|
|
506
495
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
507
496
|
|
508
|
-
self.
|
497
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
498
|
+
self._deps = self._get_dependencies()
|
509
499
|
assert isinstance(
|
510
500
|
dataset._session, Session
|
511
501
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -588,10 +578,8 @@ class HistGradientBoostingRegressor(BaseTransformer):
|
|
588
578
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
589
579
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
590
580
|
|
591
|
-
self.
|
592
|
-
|
593
|
-
inference_method=inference_method,
|
594
|
-
)
|
581
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
582
|
+
self._deps = self._get_dependencies()
|
595
583
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
596
584
|
|
597
585
|
transform_kwargs = dict(
|
@@ -658,16 +646,40 @@ class HistGradientBoostingRegressor(BaseTransformer):
|
|
658
646
|
self._is_fitted = True
|
659
647
|
return output_result
|
660
648
|
|
649
|
+
|
650
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
651
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
652
|
+
""" Method not supported for this class.
|
661
653
|
|
662
|
-
|
663
|
-
|
664
|
-
|
654
|
+
|
655
|
+
Raises:
|
656
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
657
|
+
|
658
|
+
Args:
|
659
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
660
|
+
Snowpark or Pandas DataFrame.
|
661
|
+
output_cols_prefix: Prefix for the response columns
|
665
662
|
Returns:
|
666
663
|
Transformed dataset.
|
667
664
|
"""
|
668
|
-
self.
|
669
|
-
|
670
|
-
|
665
|
+
self._infer_input_output_cols(dataset)
|
666
|
+
super()._check_dataset_type(dataset)
|
667
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
668
|
+
estimator=self._sklearn_object,
|
669
|
+
dataset=dataset,
|
670
|
+
input_cols=self.input_cols,
|
671
|
+
label_cols=self.label_cols,
|
672
|
+
sample_weight_col=self.sample_weight_col,
|
673
|
+
autogenerated=self._autogenerated,
|
674
|
+
subproject=_SUBPROJECT,
|
675
|
+
)
|
676
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
677
|
+
drop_input_cols=self._drop_input_cols,
|
678
|
+
expected_output_cols_list=self.output_cols,
|
679
|
+
)
|
680
|
+
self._sklearn_object = fitted_estimator
|
681
|
+
self._is_fitted = True
|
682
|
+
return output_result
|
671
683
|
|
672
684
|
|
673
685
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -758,10 +770,8 @@ class HistGradientBoostingRegressor(BaseTransformer):
|
|
758
770
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
759
771
|
|
760
772
|
if isinstance(dataset, DataFrame):
|
761
|
-
self.
|
762
|
-
|
763
|
-
inference_method=inference_method,
|
764
|
-
)
|
773
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
774
|
+
self._deps = self._get_dependencies()
|
765
775
|
assert isinstance(
|
766
776
|
dataset._session, Session
|
767
777
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -826,10 +836,8 @@ class HistGradientBoostingRegressor(BaseTransformer):
|
|
826
836
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
827
837
|
|
828
838
|
if isinstance(dataset, DataFrame):
|
829
|
-
self.
|
830
|
-
|
831
|
-
inference_method=inference_method,
|
832
|
-
)
|
839
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
840
|
+
self._deps = self._get_dependencies()
|
833
841
|
assert isinstance(
|
834
842
|
dataset._session, Session
|
835
843
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -891,10 +899,8 @@ class HistGradientBoostingRegressor(BaseTransformer):
|
|
891
899
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
892
900
|
|
893
901
|
if isinstance(dataset, DataFrame):
|
894
|
-
self.
|
895
|
-
|
896
|
-
inference_method=inference_method,
|
897
|
-
)
|
902
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
903
|
+
self._deps = self._get_dependencies()
|
898
904
|
assert isinstance(
|
899
905
|
dataset._session, Session
|
900
906
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -960,10 +966,8 @@ class HistGradientBoostingRegressor(BaseTransformer):
|
|
960
966
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
961
967
|
|
962
968
|
if isinstance(dataset, DataFrame):
|
963
|
-
self.
|
964
|
-
|
965
|
-
inference_method=inference_method,
|
966
|
-
)
|
969
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
970
|
+
self._deps = self._get_dependencies()
|
967
971
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
968
972
|
transform_kwargs = dict(
|
969
973
|
session=dataset._session,
|
@@ -1027,17 +1031,15 @@ class HistGradientBoostingRegressor(BaseTransformer):
|
|
1027
1031
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
1028
1032
|
|
1029
1033
|
if isinstance(dataset, DataFrame):
|
1030
|
-
self.
|
1031
|
-
|
1032
|
-
inference_method="score",
|
1033
|
-
)
|
1034
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
1035
|
+
self._deps = self._get_dependencies()
|
1034
1036
|
selected_cols = self._get_active_columns()
|
1035
1037
|
if len(selected_cols) > 0:
|
1036
1038
|
dataset = dataset.select(selected_cols)
|
1037
1039
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
1038
1040
|
transform_kwargs = dict(
|
1039
1041
|
session=dataset._session,
|
1040
|
-
dependencies=
|
1042
|
+
dependencies=self._deps,
|
1041
1043
|
score_sproc_imports=['sklearn'],
|
1042
1044
|
)
|
1043
1045
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1102,11 +1104,8 @@ class HistGradientBoostingRegressor(BaseTransformer):
|
|
1102
1104
|
|
1103
1105
|
if isinstance(dataset, DataFrame):
|
1104
1106
|
|
1105
|
-
self.
|
1106
|
-
|
1107
|
-
inference_method=inference_method,
|
1108
|
-
|
1109
|
-
)
|
1107
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1108
|
+
self._deps = self._get_dependencies()
|
1110
1109
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1111
1110
|
transform_kwargs = dict(
|
1112
1111
|
session = dataset._session,
|