snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +72 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
- snowflake/ml/_internal/telemetry.py +1 -0
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/sql_identifier.py +14 -1
- snowflake/ml/dataset/__init__.py +11 -0
- snowflake/ml/dataset/dataset.py +455 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +199 -0
- snowflake/ml/feature_store/__init__.py +6 -0
- snowflake/ml/feature_store/access_manager.py +279 -0
- snowflake/ml/feature_store/feature_store.py +544 -358
- snowflake/ml/feature_store/feature_view.py +55 -16
- snowflake/ml/fileset/embedded_stage_fs.py +149 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +160 -0
- snowflake/ml/fileset/stage_fs.py +25 -10
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +65 -31
- snowflake/ml/model/_client/model/model_version_impl.py +159 -2
- snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
- snowflake/ml/model/_client/ops/model_ops.py +268 -83
- snowflake/ml/model/_client/sql/_base.py +34 -0
- snowflake/ml/model/_client/sql/model.py +42 -47
- snowflake/ml/model/_client/sql/model_version.py +164 -39
- snowflake/ml/model/_client/sql/stage.py +6 -32
- snowflake/ml/model/_client/sql/tag.py +32 -56
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +64 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +538 -36
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/_manager/model_manager.py +36 -7
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.tree".replace("sklearn."
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class ExtraTreeRegressor(BaseTransformer):
|
70
64
|
r"""An extremely randomized tree regressor
|
71
65
|
For more details on this class, see [sklearn.tree.ExtraTreeRegressor]
|
@@ -365,20 +359,17 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
365
359
|
self,
|
366
360
|
dataset: DataFrame,
|
367
361
|
inference_method: str,
|
368
|
-
) ->
|
369
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
370
|
-
return the available package that exists in the snowflake anaconda channel
|
362
|
+
) -> None:
|
363
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
371
364
|
|
372
365
|
Args:
|
373
366
|
dataset: snowpark dataframe
|
374
367
|
inference_method: the inference method such as predict, score...
|
375
|
-
|
368
|
+
|
376
369
|
Raises:
|
377
370
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
378
371
|
SnowflakeMLException: If the session is None, raise error
|
379
372
|
|
380
|
-
Returns:
|
381
|
-
A list of available package that exists in the snowflake anaconda channel
|
382
373
|
"""
|
383
374
|
if not self._is_fitted:
|
384
375
|
raise exceptions.SnowflakeMLException(
|
@@ -396,9 +387,7 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
396
387
|
"Session must not specified for snowpark dataset."
|
397
388
|
),
|
398
389
|
)
|
399
|
-
|
400
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
401
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
390
|
+
|
402
391
|
|
403
392
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
404
393
|
@telemetry.send_api_usage_telemetry(
|
@@ -446,7 +435,8 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
446
435
|
|
447
436
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
448
437
|
|
449
|
-
self.
|
438
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
439
|
+
self._deps = self._get_dependencies()
|
450
440
|
assert isinstance(
|
451
441
|
dataset._session, Session
|
452
442
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -529,10 +519,8 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
529
519
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
530
520
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
531
521
|
|
532
|
-
self.
|
533
|
-
|
534
|
-
inference_method=inference_method,
|
535
|
-
)
|
522
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
523
|
+
self._deps = self._get_dependencies()
|
536
524
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
537
525
|
|
538
526
|
transform_kwargs = dict(
|
@@ -599,16 +587,40 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
599
587
|
self._is_fitted = True
|
600
588
|
return output_result
|
601
589
|
|
590
|
+
|
591
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
592
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
593
|
+
""" Method not supported for this class.
|
602
594
|
|
603
|
-
|
604
|
-
|
605
|
-
|
595
|
+
|
596
|
+
Raises:
|
597
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
598
|
+
|
599
|
+
Args:
|
600
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
601
|
+
Snowpark or Pandas DataFrame.
|
602
|
+
output_cols_prefix: Prefix for the response columns
|
606
603
|
Returns:
|
607
604
|
Transformed dataset.
|
608
605
|
"""
|
609
|
-
self.
|
610
|
-
|
611
|
-
|
606
|
+
self._infer_input_output_cols(dataset)
|
607
|
+
super()._check_dataset_type(dataset)
|
608
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
609
|
+
estimator=self._sklearn_object,
|
610
|
+
dataset=dataset,
|
611
|
+
input_cols=self.input_cols,
|
612
|
+
label_cols=self.label_cols,
|
613
|
+
sample_weight_col=self.sample_weight_col,
|
614
|
+
autogenerated=self._autogenerated,
|
615
|
+
subproject=_SUBPROJECT,
|
616
|
+
)
|
617
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
618
|
+
drop_input_cols=self._drop_input_cols,
|
619
|
+
expected_output_cols_list=self.output_cols,
|
620
|
+
)
|
621
|
+
self._sklearn_object = fitted_estimator
|
622
|
+
self._is_fitted = True
|
623
|
+
return output_result
|
612
624
|
|
613
625
|
|
614
626
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -699,10 +711,8 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
699
711
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
700
712
|
|
701
713
|
if isinstance(dataset, DataFrame):
|
702
|
-
self.
|
703
|
-
|
704
|
-
inference_method=inference_method,
|
705
|
-
)
|
714
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
715
|
+
self._deps = self._get_dependencies()
|
706
716
|
assert isinstance(
|
707
717
|
dataset._session, Session
|
708
718
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -767,10 +777,8 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
767
777
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
768
778
|
|
769
779
|
if isinstance(dataset, DataFrame):
|
770
|
-
self.
|
771
|
-
|
772
|
-
inference_method=inference_method,
|
773
|
-
)
|
780
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
781
|
+
self._deps = self._get_dependencies()
|
774
782
|
assert isinstance(
|
775
783
|
dataset._session, Session
|
776
784
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -832,10 +840,8 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
832
840
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
833
841
|
|
834
842
|
if isinstance(dataset, DataFrame):
|
835
|
-
self.
|
836
|
-
|
837
|
-
inference_method=inference_method,
|
838
|
-
)
|
843
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
844
|
+
self._deps = self._get_dependencies()
|
839
845
|
assert isinstance(
|
840
846
|
dataset._session, Session
|
841
847
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -901,10 +907,8 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
901
907
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
902
908
|
|
903
909
|
if isinstance(dataset, DataFrame):
|
904
|
-
self.
|
905
|
-
|
906
|
-
inference_method=inference_method,
|
907
|
-
)
|
910
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
911
|
+
self._deps = self._get_dependencies()
|
908
912
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
909
913
|
transform_kwargs = dict(
|
910
914
|
session=dataset._session,
|
@@ -968,17 +972,15 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
968
972
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
969
973
|
|
970
974
|
if isinstance(dataset, DataFrame):
|
971
|
-
self.
|
972
|
-
|
973
|
-
inference_method="score",
|
974
|
-
)
|
975
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
976
|
+
self._deps = self._get_dependencies()
|
975
977
|
selected_cols = self._get_active_columns()
|
976
978
|
if len(selected_cols) > 0:
|
977
979
|
dataset = dataset.select(selected_cols)
|
978
980
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
979
981
|
transform_kwargs = dict(
|
980
982
|
session=dataset._session,
|
981
|
-
dependencies=
|
983
|
+
dependencies=self._deps,
|
982
984
|
score_sproc_imports=['sklearn'],
|
983
985
|
)
|
984
986
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1043,11 +1045,8 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
1043
1045
|
|
1044
1046
|
if isinstance(dataset, DataFrame):
|
1045
1047
|
|
1046
|
-
self.
|
1047
|
-
|
1048
|
-
inference_method=inference_method,
|
1049
|
-
|
1050
|
-
)
|
1048
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1049
|
+
self._deps = self._get_dependencies()
|
1051
1050
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1052
1051
|
transform_kwargs = dict(
|
1053
1052
|
session = dataset._session,
|
@@ -59,12 +59,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "xgboost".replace("sklearn.", "")
|
|
59
59
|
|
60
60
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
61
61
|
|
62
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
63
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
64
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
65
|
-
return check
|
66
|
-
|
67
|
-
|
68
62
|
class XGBClassifier(BaseTransformer):
|
69
63
|
r"""Implementation of the scikit-learn API for XGBoost classification
|
70
64
|
For more details on this class, see [xgboost.XGBClassifier]
|
@@ -483,20 +477,17 @@ class XGBClassifier(BaseTransformer):
|
|
483
477
|
self,
|
484
478
|
dataset: DataFrame,
|
485
479
|
inference_method: str,
|
486
|
-
) ->
|
487
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
488
|
-
return the available package that exists in the snowflake anaconda channel
|
480
|
+
) -> None:
|
481
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
489
482
|
|
490
483
|
Args:
|
491
484
|
dataset: snowpark dataframe
|
492
485
|
inference_method: the inference method such as predict, score...
|
493
|
-
|
486
|
+
|
494
487
|
Raises:
|
495
488
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
496
489
|
SnowflakeMLException: If the session is None, raise error
|
497
490
|
|
498
|
-
Returns:
|
499
|
-
A list of available package that exists in the snowflake anaconda channel
|
500
491
|
"""
|
501
492
|
if not self._is_fitted:
|
502
493
|
raise exceptions.SnowflakeMLException(
|
@@ -514,9 +505,7 @@ class XGBClassifier(BaseTransformer):
|
|
514
505
|
"Session must not specified for snowpark dataset."
|
515
506
|
),
|
516
507
|
)
|
517
|
-
|
518
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
519
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
508
|
+
|
520
509
|
|
521
510
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
522
511
|
@telemetry.send_api_usage_telemetry(
|
@@ -564,7 +553,8 @@ class XGBClassifier(BaseTransformer):
|
|
564
553
|
|
565
554
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
566
555
|
|
567
|
-
self.
|
556
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
557
|
+
self._deps = self._get_dependencies()
|
568
558
|
assert isinstance(
|
569
559
|
dataset._session, Session
|
570
560
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -647,10 +637,8 @@ class XGBClassifier(BaseTransformer):
|
|
647
637
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
648
638
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
649
639
|
|
650
|
-
self.
|
651
|
-
|
652
|
-
inference_method=inference_method,
|
653
|
-
)
|
640
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
641
|
+
self._deps = self._get_dependencies()
|
654
642
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
655
643
|
|
656
644
|
transform_kwargs = dict(
|
@@ -717,16 +705,40 @@ class XGBClassifier(BaseTransformer):
|
|
717
705
|
self._is_fitted = True
|
718
706
|
return output_result
|
719
707
|
|
708
|
+
|
709
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
710
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
711
|
+
""" Method not supported for this class.
|
720
712
|
|
721
|
-
|
722
|
-
|
723
|
-
|
713
|
+
|
714
|
+
Raises:
|
715
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
716
|
+
|
717
|
+
Args:
|
718
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
719
|
+
Snowpark or Pandas DataFrame.
|
720
|
+
output_cols_prefix: Prefix for the response columns
|
724
721
|
Returns:
|
725
722
|
Transformed dataset.
|
726
723
|
"""
|
727
|
-
self.
|
728
|
-
|
729
|
-
|
724
|
+
self._infer_input_output_cols(dataset)
|
725
|
+
super()._check_dataset_type(dataset)
|
726
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
727
|
+
estimator=self._sklearn_object,
|
728
|
+
dataset=dataset,
|
729
|
+
input_cols=self.input_cols,
|
730
|
+
label_cols=self.label_cols,
|
731
|
+
sample_weight_col=self.sample_weight_col,
|
732
|
+
autogenerated=self._autogenerated,
|
733
|
+
subproject=_SUBPROJECT,
|
734
|
+
)
|
735
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
736
|
+
drop_input_cols=self._drop_input_cols,
|
737
|
+
expected_output_cols_list=self.output_cols,
|
738
|
+
)
|
739
|
+
self._sklearn_object = fitted_estimator
|
740
|
+
self._is_fitted = True
|
741
|
+
return output_result
|
730
742
|
|
731
743
|
|
732
744
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -819,10 +831,8 @@ class XGBClassifier(BaseTransformer):
|
|
819
831
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
820
832
|
|
821
833
|
if isinstance(dataset, DataFrame):
|
822
|
-
self.
|
823
|
-
|
824
|
-
inference_method=inference_method,
|
825
|
-
)
|
834
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
835
|
+
self._deps = self._get_dependencies()
|
826
836
|
assert isinstance(
|
827
837
|
dataset._session, Session
|
828
838
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -889,10 +899,8 @@ class XGBClassifier(BaseTransformer):
|
|
889
899
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
890
900
|
|
891
901
|
if isinstance(dataset, DataFrame):
|
892
|
-
self.
|
893
|
-
|
894
|
-
inference_method=inference_method,
|
895
|
-
)
|
902
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
903
|
+
self._deps = self._get_dependencies()
|
896
904
|
assert isinstance(
|
897
905
|
dataset._session, Session
|
898
906
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -954,10 +962,8 @@ class XGBClassifier(BaseTransformer):
|
|
954
962
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
955
963
|
|
956
964
|
if isinstance(dataset, DataFrame):
|
957
|
-
self.
|
958
|
-
|
959
|
-
inference_method=inference_method,
|
960
|
-
)
|
965
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
966
|
+
self._deps = self._get_dependencies()
|
961
967
|
assert isinstance(
|
962
968
|
dataset._session, Session
|
963
969
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -1023,10 +1029,8 @@ class XGBClassifier(BaseTransformer):
|
|
1023
1029
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
1024
1030
|
|
1025
1031
|
if isinstance(dataset, DataFrame):
|
1026
|
-
self.
|
1027
|
-
|
1028
|
-
inference_method=inference_method,
|
1029
|
-
)
|
1032
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1033
|
+
self._deps = self._get_dependencies()
|
1030
1034
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1031
1035
|
transform_kwargs = dict(
|
1032
1036
|
session=dataset._session,
|
@@ -1090,17 +1094,15 @@ class XGBClassifier(BaseTransformer):
|
|
1090
1094
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
1091
1095
|
|
1092
1096
|
if isinstance(dataset, DataFrame):
|
1093
|
-
self.
|
1094
|
-
|
1095
|
-
inference_method="score",
|
1096
|
-
)
|
1097
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
1098
|
+
self._deps = self._get_dependencies()
|
1097
1099
|
selected_cols = self._get_active_columns()
|
1098
1100
|
if len(selected_cols) > 0:
|
1099
1101
|
dataset = dataset.select(selected_cols)
|
1100
1102
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
1101
1103
|
transform_kwargs = dict(
|
1102
1104
|
session=dataset._session,
|
1103
|
-
dependencies=
|
1105
|
+
dependencies=self._deps,
|
1104
1106
|
score_sproc_imports=['xgboost'],
|
1105
1107
|
)
|
1106
1108
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1165,11 +1167,8 @@ class XGBClassifier(BaseTransformer):
|
|
1165
1167
|
|
1166
1168
|
if isinstance(dataset, DataFrame):
|
1167
1169
|
|
1168
|
-
self.
|
1169
|
-
|
1170
|
-
inference_method=inference_method,
|
1171
|
-
|
1172
|
-
)
|
1170
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1171
|
+
self._deps = self._get_dependencies()
|
1173
1172
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1174
1173
|
transform_kwargs = dict(
|
1175
1174
|
session = dataset._session,
|
@@ -59,12 +59,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "xgboost".replace("sklearn.", "")
|
|
59
59
|
|
60
60
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
61
61
|
|
62
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
63
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
64
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
65
|
-
return check
|
66
|
-
|
67
|
-
|
68
62
|
class XGBRegressor(BaseTransformer):
|
69
63
|
r"""Implementation of the scikit-learn API for XGBoost regression
|
70
64
|
For more details on this class, see [xgboost.XGBRegressor]
|
@@ -482,20 +476,17 @@ class XGBRegressor(BaseTransformer):
|
|
482
476
|
self,
|
483
477
|
dataset: DataFrame,
|
484
478
|
inference_method: str,
|
485
|
-
) ->
|
486
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
487
|
-
return the available package that exists in the snowflake anaconda channel
|
479
|
+
) -> None:
|
480
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
488
481
|
|
489
482
|
Args:
|
490
483
|
dataset: snowpark dataframe
|
491
484
|
inference_method: the inference method such as predict, score...
|
492
|
-
|
485
|
+
|
493
486
|
Raises:
|
494
487
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
495
488
|
SnowflakeMLException: If the session is None, raise error
|
496
489
|
|
497
|
-
Returns:
|
498
|
-
A list of available package that exists in the snowflake anaconda channel
|
499
490
|
"""
|
500
491
|
if not self._is_fitted:
|
501
492
|
raise exceptions.SnowflakeMLException(
|
@@ -513,9 +504,7 @@ class XGBRegressor(BaseTransformer):
|
|
513
504
|
"Session must not specified for snowpark dataset."
|
514
505
|
),
|
515
506
|
)
|
516
|
-
|
517
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
518
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
507
|
+
|
519
508
|
|
520
509
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
521
510
|
@telemetry.send_api_usage_telemetry(
|
@@ -563,7 +552,8 @@ class XGBRegressor(BaseTransformer):
|
|
563
552
|
|
564
553
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
565
554
|
|
566
|
-
self.
|
555
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
556
|
+
self._deps = self._get_dependencies()
|
567
557
|
assert isinstance(
|
568
558
|
dataset._session, Session
|
569
559
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -646,10 +636,8 @@ class XGBRegressor(BaseTransformer):
|
|
646
636
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
647
637
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
648
638
|
|
649
|
-
self.
|
650
|
-
|
651
|
-
inference_method=inference_method,
|
652
|
-
)
|
639
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
640
|
+
self._deps = self._get_dependencies()
|
653
641
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
654
642
|
|
655
643
|
transform_kwargs = dict(
|
@@ -716,16 +704,40 @@ class XGBRegressor(BaseTransformer):
|
|
716
704
|
self._is_fitted = True
|
717
705
|
return output_result
|
718
706
|
|
707
|
+
|
708
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
709
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
710
|
+
""" Method not supported for this class.
|
719
711
|
|
720
|
-
|
721
|
-
|
722
|
-
|
712
|
+
|
713
|
+
Raises:
|
714
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
715
|
+
|
716
|
+
Args:
|
717
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
718
|
+
Snowpark or Pandas DataFrame.
|
719
|
+
output_cols_prefix: Prefix for the response columns
|
723
720
|
Returns:
|
724
721
|
Transformed dataset.
|
725
722
|
"""
|
726
|
-
self.
|
727
|
-
|
728
|
-
|
723
|
+
self._infer_input_output_cols(dataset)
|
724
|
+
super()._check_dataset_type(dataset)
|
725
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
726
|
+
estimator=self._sklearn_object,
|
727
|
+
dataset=dataset,
|
728
|
+
input_cols=self.input_cols,
|
729
|
+
label_cols=self.label_cols,
|
730
|
+
sample_weight_col=self.sample_weight_col,
|
731
|
+
autogenerated=self._autogenerated,
|
732
|
+
subproject=_SUBPROJECT,
|
733
|
+
)
|
734
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
735
|
+
drop_input_cols=self._drop_input_cols,
|
736
|
+
expected_output_cols_list=self.output_cols,
|
737
|
+
)
|
738
|
+
self._sklearn_object = fitted_estimator
|
739
|
+
self._is_fitted = True
|
740
|
+
return output_result
|
729
741
|
|
730
742
|
|
731
743
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -816,10 +828,8 @@ class XGBRegressor(BaseTransformer):
|
|
816
828
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
817
829
|
|
818
830
|
if isinstance(dataset, DataFrame):
|
819
|
-
self.
|
820
|
-
|
821
|
-
inference_method=inference_method,
|
822
|
-
)
|
831
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
832
|
+
self._deps = self._get_dependencies()
|
823
833
|
assert isinstance(
|
824
834
|
dataset._session, Session
|
825
835
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -884,10 +894,8 @@ class XGBRegressor(BaseTransformer):
|
|
884
894
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
885
895
|
|
886
896
|
if isinstance(dataset, DataFrame):
|
887
|
-
self.
|
888
|
-
|
889
|
-
inference_method=inference_method,
|
890
|
-
)
|
897
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
898
|
+
self._deps = self._get_dependencies()
|
891
899
|
assert isinstance(
|
892
900
|
dataset._session, Session
|
893
901
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -949,10 +957,8 @@ class XGBRegressor(BaseTransformer):
|
|
949
957
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
950
958
|
|
951
959
|
if isinstance(dataset, DataFrame):
|
952
|
-
self.
|
953
|
-
|
954
|
-
inference_method=inference_method,
|
955
|
-
)
|
960
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
961
|
+
self._deps = self._get_dependencies()
|
956
962
|
assert isinstance(
|
957
963
|
dataset._session, Session
|
958
964
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -1018,10 +1024,8 @@ class XGBRegressor(BaseTransformer):
|
|
1018
1024
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
1019
1025
|
|
1020
1026
|
if isinstance(dataset, DataFrame):
|
1021
|
-
self.
|
1022
|
-
|
1023
|
-
inference_method=inference_method,
|
1024
|
-
)
|
1027
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1028
|
+
self._deps = self._get_dependencies()
|
1025
1029
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1026
1030
|
transform_kwargs = dict(
|
1027
1031
|
session=dataset._session,
|
@@ -1085,17 +1089,15 @@ class XGBRegressor(BaseTransformer):
|
|
1085
1089
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
1086
1090
|
|
1087
1091
|
if isinstance(dataset, DataFrame):
|
1088
|
-
self.
|
1089
|
-
|
1090
|
-
inference_method="score",
|
1091
|
-
)
|
1092
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
1093
|
+
self._deps = self._get_dependencies()
|
1092
1094
|
selected_cols = self._get_active_columns()
|
1093
1095
|
if len(selected_cols) > 0:
|
1094
1096
|
dataset = dataset.select(selected_cols)
|
1095
1097
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
1096
1098
|
transform_kwargs = dict(
|
1097
1099
|
session=dataset._session,
|
1098
|
-
dependencies=
|
1100
|
+
dependencies=self._deps,
|
1099
1101
|
score_sproc_imports=['xgboost'],
|
1100
1102
|
)
|
1101
1103
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1160,11 +1162,8 @@ class XGBRegressor(BaseTransformer):
|
|
1160
1162
|
|
1161
1163
|
if isinstance(dataset, DataFrame):
|
1162
1164
|
|
1163
|
-
self.
|
1164
|
-
|
1165
|
-
inference_method=inference_method,
|
1166
|
-
|
1167
|
-
)
|
1165
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1166
|
+
self._deps = self._get_dependencies()
|
1168
1167
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1169
1168
|
transform_kwargs = dict(
|
1170
1169
|
session = dataset._session,
|