snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class ExtraTreesClassifier(BaseTransformer):
|
70
64
|
r"""An extra-trees classifier
|
71
65
|
For more details on this class, see [sklearn.ensemble.ExtraTreesClassifier]
|
@@ -440,20 +434,17 @@ class ExtraTreesClassifier(BaseTransformer):
|
|
440
434
|
self,
|
441
435
|
dataset: DataFrame,
|
442
436
|
inference_method: str,
|
443
|
-
) ->
|
444
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
445
|
-
return the available package that exists in the snowflake anaconda channel
|
437
|
+
) -> None:
|
438
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
446
439
|
|
447
440
|
Args:
|
448
441
|
dataset: snowpark dataframe
|
449
442
|
inference_method: the inference method such as predict, score...
|
450
|
-
|
443
|
+
|
451
444
|
Raises:
|
452
445
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
453
446
|
SnowflakeMLException: If the session is None, raise error
|
454
447
|
|
455
|
-
Returns:
|
456
|
-
A list of available package that exists in the snowflake anaconda channel
|
457
448
|
"""
|
458
449
|
if not self._is_fitted:
|
459
450
|
raise exceptions.SnowflakeMLException(
|
@@ -471,9 +462,7 @@ class ExtraTreesClassifier(BaseTransformer):
|
|
471
462
|
"Session must not specified for snowpark dataset."
|
472
463
|
),
|
473
464
|
)
|
474
|
-
|
475
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
476
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
465
|
+
|
477
466
|
|
478
467
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
479
468
|
@telemetry.send_api_usage_telemetry(
|
@@ -521,7 +510,8 @@ class ExtraTreesClassifier(BaseTransformer):
|
|
521
510
|
|
522
511
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
523
512
|
|
524
|
-
self.
|
513
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
514
|
+
self._deps = self._get_dependencies()
|
525
515
|
assert isinstance(
|
526
516
|
dataset._session, Session
|
527
517
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -604,10 +594,8 @@ class ExtraTreesClassifier(BaseTransformer):
|
|
604
594
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
605
595
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
606
596
|
|
607
|
-
self.
|
608
|
-
|
609
|
-
inference_method=inference_method,
|
610
|
-
)
|
597
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
598
|
+
self._deps = self._get_dependencies()
|
611
599
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
612
600
|
|
613
601
|
transform_kwargs = dict(
|
@@ -674,16 +662,40 @@ class ExtraTreesClassifier(BaseTransformer):
|
|
674
662
|
self._is_fitted = True
|
675
663
|
return output_result
|
676
664
|
|
665
|
+
|
666
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
667
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
668
|
+
""" Method not supported for this class.
|
677
669
|
|
678
|
-
|
679
|
-
|
680
|
-
|
670
|
+
|
671
|
+
Raises:
|
672
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
673
|
+
|
674
|
+
Args:
|
675
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
676
|
+
Snowpark or Pandas DataFrame.
|
677
|
+
output_cols_prefix: Prefix for the response columns
|
681
678
|
Returns:
|
682
679
|
Transformed dataset.
|
683
680
|
"""
|
684
|
-
self.
|
685
|
-
|
686
|
-
|
681
|
+
self._infer_input_output_cols(dataset)
|
682
|
+
super()._check_dataset_type(dataset)
|
683
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
684
|
+
estimator=self._sklearn_object,
|
685
|
+
dataset=dataset,
|
686
|
+
input_cols=self.input_cols,
|
687
|
+
label_cols=self.label_cols,
|
688
|
+
sample_weight_col=self.sample_weight_col,
|
689
|
+
autogenerated=self._autogenerated,
|
690
|
+
subproject=_SUBPROJECT,
|
691
|
+
)
|
692
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
693
|
+
drop_input_cols=self._drop_input_cols,
|
694
|
+
expected_output_cols_list=self.output_cols,
|
695
|
+
)
|
696
|
+
self._sklearn_object = fitted_estimator
|
697
|
+
self._is_fitted = True
|
698
|
+
return output_result
|
687
699
|
|
688
700
|
|
689
701
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -776,10 +788,8 @@ class ExtraTreesClassifier(BaseTransformer):
|
|
776
788
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
777
789
|
|
778
790
|
if isinstance(dataset, DataFrame):
|
779
|
-
self.
|
780
|
-
|
781
|
-
inference_method=inference_method,
|
782
|
-
)
|
791
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
792
|
+
self._deps = self._get_dependencies()
|
783
793
|
assert isinstance(
|
784
794
|
dataset._session, Session
|
785
795
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -846,10 +856,8 @@ class ExtraTreesClassifier(BaseTransformer):
|
|
846
856
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
847
857
|
|
848
858
|
if isinstance(dataset, DataFrame):
|
849
|
-
self.
|
850
|
-
|
851
|
-
inference_method=inference_method,
|
852
|
-
)
|
859
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
860
|
+
self._deps = self._get_dependencies()
|
853
861
|
assert isinstance(
|
854
862
|
dataset._session, Session
|
855
863
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -911,10 +919,8 @@ class ExtraTreesClassifier(BaseTransformer):
|
|
911
919
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
912
920
|
|
913
921
|
if isinstance(dataset, DataFrame):
|
914
|
-
self.
|
915
|
-
|
916
|
-
inference_method=inference_method,
|
917
|
-
)
|
922
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
923
|
+
self._deps = self._get_dependencies()
|
918
924
|
assert isinstance(
|
919
925
|
dataset._session, Session
|
920
926
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -980,10 +986,8 @@ class ExtraTreesClassifier(BaseTransformer):
|
|
980
986
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
981
987
|
|
982
988
|
if isinstance(dataset, DataFrame):
|
983
|
-
self.
|
984
|
-
|
985
|
-
inference_method=inference_method,
|
986
|
-
)
|
989
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
990
|
+
self._deps = self._get_dependencies()
|
987
991
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
988
992
|
transform_kwargs = dict(
|
989
993
|
session=dataset._session,
|
@@ -1047,17 +1051,15 @@ class ExtraTreesClassifier(BaseTransformer):
|
|
1047
1051
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
1048
1052
|
|
1049
1053
|
if isinstance(dataset, DataFrame):
|
1050
|
-
self.
|
1051
|
-
|
1052
|
-
inference_method="score",
|
1053
|
-
)
|
1054
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
1055
|
+
self._deps = self._get_dependencies()
|
1054
1056
|
selected_cols = self._get_active_columns()
|
1055
1057
|
if len(selected_cols) > 0:
|
1056
1058
|
dataset = dataset.select(selected_cols)
|
1057
1059
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
1058
1060
|
transform_kwargs = dict(
|
1059
1061
|
session=dataset._session,
|
1060
|
-
dependencies=
|
1062
|
+
dependencies=self._deps,
|
1061
1063
|
score_sproc_imports=['sklearn'],
|
1062
1064
|
)
|
1063
1065
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1122,11 +1124,8 @@ class ExtraTreesClassifier(BaseTransformer):
|
|
1122
1124
|
|
1123
1125
|
if isinstance(dataset, DataFrame):
|
1124
1126
|
|
1125
|
-
self.
|
1126
|
-
|
1127
|
-
inference_method=inference_method,
|
1128
|
-
|
1129
|
-
)
|
1127
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1128
|
+
self._deps = self._get_dependencies()
|
1130
1129
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1131
1130
|
transform_kwargs = dict(
|
1132
1131
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class ExtraTreesRegressor(BaseTransformer):
|
70
64
|
r"""An extra-trees regressor
|
71
65
|
For more details on this class, see [sklearn.ensemble.ExtraTreesRegressor]
|
@@ -419,20 +413,17 @@ class ExtraTreesRegressor(BaseTransformer):
|
|
419
413
|
self,
|
420
414
|
dataset: DataFrame,
|
421
415
|
inference_method: str,
|
422
|
-
) ->
|
423
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
424
|
-
return the available package that exists in the snowflake anaconda channel
|
416
|
+
) -> None:
|
417
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
425
418
|
|
426
419
|
Args:
|
427
420
|
dataset: snowpark dataframe
|
428
421
|
inference_method: the inference method such as predict, score...
|
429
|
-
|
422
|
+
|
430
423
|
Raises:
|
431
424
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
432
425
|
SnowflakeMLException: If the session is None, raise error
|
433
426
|
|
434
|
-
Returns:
|
435
|
-
A list of available package that exists in the snowflake anaconda channel
|
436
427
|
"""
|
437
428
|
if not self._is_fitted:
|
438
429
|
raise exceptions.SnowflakeMLException(
|
@@ -450,9 +441,7 @@ class ExtraTreesRegressor(BaseTransformer):
|
|
450
441
|
"Session must not specified for snowpark dataset."
|
451
442
|
),
|
452
443
|
)
|
453
|
-
|
454
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
455
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
444
|
+
|
456
445
|
|
457
446
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
458
447
|
@telemetry.send_api_usage_telemetry(
|
@@ -500,7 +489,8 @@ class ExtraTreesRegressor(BaseTransformer):
|
|
500
489
|
|
501
490
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
502
491
|
|
503
|
-
self.
|
492
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
493
|
+
self._deps = self._get_dependencies()
|
504
494
|
assert isinstance(
|
505
495
|
dataset._session, Session
|
506
496
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -583,10 +573,8 @@ class ExtraTreesRegressor(BaseTransformer):
|
|
583
573
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
584
574
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
585
575
|
|
586
|
-
self.
|
587
|
-
|
588
|
-
inference_method=inference_method,
|
589
|
-
)
|
576
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
577
|
+
self._deps = self._get_dependencies()
|
590
578
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
591
579
|
|
592
580
|
transform_kwargs = dict(
|
@@ -653,16 +641,40 @@ class ExtraTreesRegressor(BaseTransformer):
|
|
653
641
|
self._is_fitted = True
|
654
642
|
return output_result
|
655
643
|
|
644
|
+
|
645
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
646
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
647
|
+
""" Method not supported for this class.
|
656
648
|
|
657
|
-
|
658
|
-
|
659
|
-
|
649
|
+
|
650
|
+
Raises:
|
651
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
652
|
+
|
653
|
+
Args:
|
654
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
655
|
+
Snowpark or Pandas DataFrame.
|
656
|
+
output_cols_prefix: Prefix for the response columns
|
660
657
|
Returns:
|
661
658
|
Transformed dataset.
|
662
659
|
"""
|
663
|
-
self.
|
664
|
-
|
665
|
-
|
660
|
+
self._infer_input_output_cols(dataset)
|
661
|
+
super()._check_dataset_type(dataset)
|
662
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
663
|
+
estimator=self._sklearn_object,
|
664
|
+
dataset=dataset,
|
665
|
+
input_cols=self.input_cols,
|
666
|
+
label_cols=self.label_cols,
|
667
|
+
sample_weight_col=self.sample_weight_col,
|
668
|
+
autogenerated=self._autogenerated,
|
669
|
+
subproject=_SUBPROJECT,
|
670
|
+
)
|
671
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
672
|
+
drop_input_cols=self._drop_input_cols,
|
673
|
+
expected_output_cols_list=self.output_cols,
|
674
|
+
)
|
675
|
+
self._sklearn_object = fitted_estimator
|
676
|
+
self._is_fitted = True
|
677
|
+
return output_result
|
666
678
|
|
667
679
|
|
668
680
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -753,10 +765,8 @@ class ExtraTreesRegressor(BaseTransformer):
|
|
753
765
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
754
766
|
|
755
767
|
if isinstance(dataset, DataFrame):
|
756
|
-
self.
|
757
|
-
|
758
|
-
inference_method=inference_method,
|
759
|
-
)
|
768
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
769
|
+
self._deps = self._get_dependencies()
|
760
770
|
assert isinstance(
|
761
771
|
dataset._session, Session
|
762
772
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -821,10 +831,8 @@ class ExtraTreesRegressor(BaseTransformer):
|
|
821
831
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
822
832
|
|
823
833
|
if isinstance(dataset, DataFrame):
|
824
|
-
self.
|
825
|
-
|
826
|
-
inference_method=inference_method,
|
827
|
-
)
|
834
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
835
|
+
self._deps = self._get_dependencies()
|
828
836
|
assert isinstance(
|
829
837
|
dataset._session, Session
|
830
838
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -886,10 +894,8 @@ class ExtraTreesRegressor(BaseTransformer):
|
|
886
894
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
887
895
|
|
888
896
|
if isinstance(dataset, DataFrame):
|
889
|
-
self.
|
890
|
-
|
891
|
-
inference_method=inference_method,
|
892
|
-
)
|
897
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
898
|
+
self._deps = self._get_dependencies()
|
893
899
|
assert isinstance(
|
894
900
|
dataset._session, Session
|
895
901
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -955,10 +961,8 @@ class ExtraTreesRegressor(BaseTransformer):
|
|
955
961
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
956
962
|
|
957
963
|
if isinstance(dataset, DataFrame):
|
958
|
-
self.
|
959
|
-
|
960
|
-
inference_method=inference_method,
|
961
|
-
)
|
964
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
965
|
+
self._deps = self._get_dependencies()
|
962
966
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
963
967
|
transform_kwargs = dict(
|
964
968
|
session=dataset._session,
|
@@ -1022,17 +1026,15 @@ class ExtraTreesRegressor(BaseTransformer):
|
|
1022
1026
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
1023
1027
|
|
1024
1028
|
if isinstance(dataset, DataFrame):
|
1025
|
-
self.
|
1026
|
-
|
1027
|
-
inference_method="score",
|
1028
|
-
)
|
1029
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
1030
|
+
self._deps = self._get_dependencies()
|
1029
1031
|
selected_cols = self._get_active_columns()
|
1030
1032
|
if len(selected_cols) > 0:
|
1031
1033
|
dataset = dataset.select(selected_cols)
|
1032
1034
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
1033
1035
|
transform_kwargs = dict(
|
1034
1036
|
session=dataset._session,
|
1035
|
-
dependencies=
|
1037
|
+
dependencies=self._deps,
|
1036
1038
|
score_sproc_imports=['sklearn'],
|
1037
1039
|
)
|
1038
1040
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1097,11 +1099,8 @@ class ExtraTreesRegressor(BaseTransformer):
|
|
1097
1099
|
|
1098
1100
|
if isinstance(dataset, DataFrame):
|
1099
1101
|
|
1100
|
-
self.
|
1101
|
-
|
1102
|
-
inference_method=inference_method,
|
1103
|
-
|
1104
|
-
)
|
1102
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1103
|
+
self._deps = self._get_dependencies()
|
1105
1104
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1106
1105
|
transform_kwargs = dict(
|
1107
1106
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class GradientBoostingClassifier(BaseTransformer):
|
70
64
|
r"""Gradient Boosting for classification
|
71
65
|
For more details on this class, see [sklearn.ensemble.GradientBoostingClassifier]
|
@@ -452,20 +446,17 @@ class GradientBoostingClassifier(BaseTransformer):
|
|
452
446
|
self,
|
453
447
|
dataset: DataFrame,
|
454
448
|
inference_method: str,
|
455
|
-
) ->
|
456
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
457
|
-
return the available package that exists in the snowflake anaconda channel
|
449
|
+
) -> None:
|
450
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
458
451
|
|
459
452
|
Args:
|
460
453
|
dataset: snowpark dataframe
|
461
454
|
inference_method: the inference method such as predict, score...
|
462
|
-
|
455
|
+
|
463
456
|
Raises:
|
464
457
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
465
458
|
SnowflakeMLException: If the session is None, raise error
|
466
459
|
|
467
|
-
Returns:
|
468
|
-
A list of available package that exists in the snowflake anaconda channel
|
469
460
|
"""
|
470
461
|
if not self._is_fitted:
|
471
462
|
raise exceptions.SnowflakeMLException(
|
@@ -483,9 +474,7 @@ class GradientBoostingClassifier(BaseTransformer):
|
|
483
474
|
"Session must not specified for snowpark dataset."
|
484
475
|
),
|
485
476
|
)
|
486
|
-
|
487
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
488
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
477
|
+
|
489
478
|
|
490
479
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
491
480
|
@telemetry.send_api_usage_telemetry(
|
@@ -533,7 +522,8 @@ class GradientBoostingClassifier(BaseTransformer):
|
|
533
522
|
|
534
523
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
535
524
|
|
536
|
-
self.
|
525
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
526
|
+
self._deps = self._get_dependencies()
|
537
527
|
assert isinstance(
|
538
528
|
dataset._session, Session
|
539
529
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -616,10 +606,8 @@ class GradientBoostingClassifier(BaseTransformer):
|
|
616
606
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
617
607
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
618
608
|
|
619
|
-
self.
|
620
|
-
|
621
|
-
inference_method=inference_method,
|
622
|
-
)
|
609
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
610
|
+
self._deps = self._get_dependencies()
|
623
611
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
624
612
|
|
625
613
|
transform_kwargs = dict(
|
@@ -686,16 +674,40 @@ class GradientBoostingClassifier(BaseTransformer):
|
|
686
674
|
self._is_fitted = True
|
687
675
|
return output_result
|
688
676
|
|
677
|
+
|
678
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
679
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
680
|
+
""" Method not supported for this class.
|
689
681
|
|
690
|
-
|
691
|
-
|
692
|
-
|
682
|
+
|
683
|
+
Raises:
|
684
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
685
|
+
|
686
|
+
Args:
|
687
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
688
|
+
Snowpark or Pandas DataFrame.
|
689
|
+
output_cols_prefix: Prefix for the response columns
|
693
690
|
Returns:
|
694
691
|
Transformed dataset.
|
695
692
|
"""
|
696
|
-
self.
|
697
|
-
|
698
|
-
|
693
|
+
self._infer_input_output_cols(dataset)
|
694
|
+
super()._check_dataset_type(dataset)
|
695
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
696
|
+
estimator=self._sklearn_object,
|
697
|
+
dataset=dataset,
|
698
|
+
input_cols=self.input_cols,
|
699
|
+
label_cols=self.label_cols,
|
700
|
+
sample_weight_col=self.sample_weight_col,
|
701
|
+
autogenerated=self._autogenerated,
|
702
|
+
subproject=_SUBPROJECT,
|
703
|
+
)
|
704
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
705
|
+
drop_input_cols=self._drop_input_cols,
|
706
|
+
expected_output_cols_list=self.output_cols,
|
707
|
+
)
|
708
|
+
self._sklearn_object = fitted_estimator
|
709
|
+
self._is_fitted = True
|
710
|
+
return output_result
|
699
711
|
|
700
712
|
|
701
713
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -788,10 +800,8 @@ class GradientBoostingClassifier(BaseTransformer):
|
|
788
800
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
789
801
|
|
790
802
|
if isinstance(dataset, DataFrame):
|
791
|
-
self.
|
792
|
-
|
793
|
-
inference_method=inference_method,
|
794
|
-
)
|
803
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
804
|
+
self._deps = self._get_dependencies()
|
795
805
|
assert isinstance(
|
796
806
|
dataset._session, Session
|
797
807
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -858,10 +868,8 @@ class GradientBoostingClassifier(BaseTransformer):
|
|
858
868
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
859
869
|
|
860
870
|
if isinstance(dataset, DataFrame):
|
861
|
-
self.
|
862
|
-
|
863
|
-
inference_method=inference_method,
|
864
|
-
)
|
871
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
872
|
+
self._deps = self._get_dependencies()
|
865
873
|
assert isinstance(
|
866
874
|
dataset._session, Session
|
867
875
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -925,10 +933,8 @@ class GradientBoostingClassifier(BaseTransformer):
|
|
925
933
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
926
934
|
|
927
935
|
if isinstance(dataset, DataFrame):
|
928
|
-
self.
|
929
|
-
|
930
|
-
inference_method=inference_method,
|
931
|
-
)
|
936
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
937
|
+
self._deps = self._get_dependencies()
|
932
938
|
assert isinstance(
|
933
939
|
dataset._session, Session
|
934
940
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -994,10 +1000,8 @@ class GradientBoostingClassifier(BaseTransformer):
|
|
994
1000
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
995
1001
|
|
996
1002
|
if isinstance(dataset, DataFrame):
|
997
|
-
self.
|
998
|
-
|
999
|
-
inference_method=inference_method,
|
1000
|
-
)
|
1003
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1004
|
+
self._deps = self._get_dependencies()
|
1001
1005
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1002
1006
|
transform_kwargs = dict(
|
1003
1007
|
session=dataset._session,
|
@@ -1061,17 +1065,15 @@ class GradientBoostingClassifier(BaseTransformer):
|
|
1061
1065
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
1062
1066
|
|
1063
1067
|
if isinstance(dataset, DataFrame):
|
1064
|
-
self.
|
1065
|
-
|
1066
|
-
inference_method="score",
|
1067
|
-
)
|
1068
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
1069
|
+
self._deps = self._get_dependencies()
|
1068
1070
|
selected_cols = self._get_active_columns()
|
1069
1071
|
if len(selected_cols) > 0:
|
1070
1072
|
dataset = dataset.select(selected_cols)
|
1071
1073
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
1072
1074
|
transform_kwargs = dict(
|
1073
1075
|
session=dataset._session,
|
1074
|
-
dependencies=
|
1076
|
+
dependencies=self._deps,
|
1075
1077
|
score_sproc_imports=['sklearn'],
|
1076
1078
|
)
|
1077
1079
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1136,11 +1138,8 @@ class GradientBoostingClassifier(BaseTransformer):
|
|
1136
1138
|
|
1137
1139
|
if isinstance(dataset, DataFrame):
|
1138
1140
|
|
1139
|
-
self.
|
1140
|
-
|
1141
|
-
inference_method=inference_method,
|
1142
|
-
|
1143
|
-
)
|
1141
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1142
|
+
self._deps = self._get_dependencies()
|
1144
1143
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1145
1144
|
transform_kwargs = dict(
|
1146
1145
|
session = dataset._session,
|