snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +72 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
- snowflake/ml/_internal/telemetry.py +1 -0
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/sql_identifier.py +14 -1
- snowflake/ml/dataset/__init__.py +11 -0
- snowflake/ml/dataset/dataset.py +455 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +199 -0
- snowflake/ml/feature_store/__init__.py +6 -0
- snowflake/ml/feature_store/access_manager.py +279 -0
- snowflake/ml/feature_store/feature_store.py +544 -358
- snowflake/ml/feature_store/feature_view.py +55 -16
- snowflake/ml/fileset/embedded_stage_fs.py +149 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +160 -0
- snowflake/ml/fileset/stage_fs.py +25 -10
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +65 -31
- snowflake/ml/model/_client/model/model_version_impl.py +159 -2
- snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
- snowflake/ml/model/_client/ops/model_ops.py +268 -83
- snowflake/ml/model/_client/sql/_base.py +34 -0
- snowflake/ml/model/_client/sql/model.py +42 -47
- snowflake/ml/model/_client/sql/model_version.py +164 -39
- snowflake/ml/model/_client/sql/stage.py +6 -32
- snowflake/ml/model/_client/sql/tag.py +32 -56
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +64 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +538 -36
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/_manager/model_manager.py +36 -7
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.tree".replace("sklearn."
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class DecisionTreeClassifier(BaseTransformer):
|
70
64
|
r"""A decision tree classifier
|
71
65
|
For more details on this class, see [sklearn.tree.DecisionTreeClassifier]
|
@@ -391,20 +385,17 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
391
385
|
self,
|
392
386
|
dataset: DataFrame,
|
393
387
|
inference_method: str,
|
394
|
-
) ->
|
395
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
396
|
-
return the available package that exists in the snowflake anaconda channel
|
388
|
+
) -> None:
|
389
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
397
390
|
|
398
391
|
Args:
|
399
392
|
dataset: snowpark dataframe
|
400
393
|
inference_method: the inference method such as predict, score...
|
401
|
-
|
394
|
+
|
402
395
|
Raises:
|
403
396
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
404
397
|
SnowflakeMLException: If the session is None, raise error
|
405
398
|
|
406
|
-
Returns:
|
407
|
-
A list of available package that exists in the snowflake anaconda channel
|
408
399
|
"""
|
409
400
|
if not self._is_fitted:
|
410
401
|
raise exceptions.SnowflakeMLException(
|
@@ -422,9 +413,7 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
422
413
|
"Session must not specified for snowpark dataset."
|
423
414
|
),
|
424
415
|
)
|
425
|
-
|
426
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
427
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
416
|
+
|
428
417
|
|
429
418
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
430
419
|
@telemetry.send_api_usage_telemetry(
|
@@ -472,7 +461,8 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
472
461
|
|
473
462
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
474
463
|
|
475
|
-
self.
|
464
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
465
|
+
self._deps = self._get_dependencies()
|
476
466
|
assert isinstance(
|
477
467
|
dataset._session, Session
|
478
468
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -555,10 +545,8 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
555
545
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
556
546
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
557
547
|
|
558
|
-
self.
|
559
|
-
|
560
|
-
inference_method=inference_method,
|
561
|
-
)
|
548
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
549
|
+
self._deps = self._get_dependencies()
|
562
550
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
563
551
|
|
564
552
|
transform_kwargs = dict(
|
@@ -625,16 +613,40 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
625
613
|
self._is_fitted = True
|
626
614
|
return output_result
|
627
615
|
|
616
|
+
|
617
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
618
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
619
|
+
""" Method not supported for this class.
|
628
620
|
|
629
|
-
|
630
|
-
|
631
|
-
|
621
|
+
|
622
|
+
Raises:
|
623
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
624
|
+
|
625
|
+
Args:
|
626
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
627
|
+
Snowpark or Pandas DataFrame.
|
628
|
+
output_cols_prefix: Prefix for the response columns
|
632
629
|
Returns:
|
633
630
|
Transformed dataset.
|
634
631
|
"""
|
635
|
-
self.
|
636
|
-
|
637
|
-
|
632
|
+
self._infer_input_output_cols(dataset)
|
633
|
+
super()._check_dataset_type(dataset)
|
634
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
635
|
+
estimator=self._sklearn_object,
|
636
|
+
dataset=dataset,
|
637
|
+
input_cols=self.input_cols,
|
638
|
+
label_cols=self.label_cols,
|
639
|
+
sample_weight_col=self.sample_weight_col,
|
640
|
+
autogenerated=self._autogenerated,
|
641
|
+
subproject=_SUBPROJECT,
|
642
|
+
)
|
643
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
644
|
+
drop_input_cols=self._drop_input_cols,
|
645
|
+
expected_output_cols_list=self.output_cols,
|
646
|
+
)
|
647
|
+
self._sklearn_object = fitted_estimator
|
648
|
+
self._is_fitted = True
|
649
|
+
return output_result
|
638
650
|
|
639
651
|
|
640
652
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -727,10 +739,8 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
727
739
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
728
740
|
|
729
741
|
if isinstance(dataset, DataFrame):
|
730
|
-
self.
|
731
|
-
|
732
|
-
inference_method=inference_method,
|
733
|
-
)
|
742
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
743
|
+
self._deps = self._get_dependencies()
|
734
744
|
assert isinstance(
|
735
745
|
dataset._session, Session
|
736
746
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -797,10 +807,8 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
797
807
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
798
808
|
|
799
809
|
if isinstance(dataset, DataFrame):
|
800
|
-
self.
|
801
|
-
|
802
|
-
inference_method=inference_method,
|
803
|
-
)
|
810
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
811
|
+
self._deps = self._get_dependencies()
|
804
812
|
assert isinstance(
|
805
813
|
dataset._session, Session
|
806
814
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -862,10 +870,8 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
862
870
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
863
871
|
|
864
872
|
if isinstance(dataset, DataFrame):
|
865
|
-
self.
|
866
|
-
|
867
|
-
inference_method=inference_method,
|
868
|
-
)
|
873
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
874
|
+
self._deps = self._get_dependencies()
|
869
875
|
assert isinstance(
|
870
876
|
dataset._session, Session
|
871
877
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -931,10 +937,8 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
931
937
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
932
938
|
|
933
939
|
if isinstance(dataset, DataFrame):
|
934
|
-
self.
|
935
|
-
|
936
|
-
inference_method=inference_method,
|
937
|
-
)
|
940
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
941
|
+
self._deps = self._get_dependencies()
|
938
942
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
939
943
|
transform_kwargs = dict(
|
940
944
|
session=dataset._session,
|
@@ -998,17 +1002,15 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
998
1002
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
999
1003
|
|
1000
1004
|
if isinstance(dataset, DataFrame):
|
1001
|
-
self.
|
1002
|
-
|
1003
|
-
inference_method="score",
|
1004
|
-
)
|
1005
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
1006
|
+
self._deps = self._get_dependencies()
|
1005
1007
|
selected_cols = self._get_active_columns()
|
1006
1008
|
if len(selected_cols) > 0:
|
1007
1009
|
dataset = dataset.select(selected_cols)
|
1008
1010
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
1009
1011
|
transform_kwargs = dict(
|
1010
1012
|
session=dataset._session,
|
1011
|
-
dependencies=
|
1013
|
+
dependencies=self._deps,
|
1012
1014
|
score_sproc_imports=['sklearn'],
|
1013
1015
|
)
|
1014
1016
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1073,11 +1075,8 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
1073
1075
|
|
1074
1076
|
if isinstance(dataset, DataFrame):
|
1075
1077
|
|
1076
|
-
self.
|
1077
|
-
|
1078
|
-
inference_method=inference_method,
|
1079
|
-
|
1080
|
-
)
|
1078
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1079
|
+
self._deps = self._get_dependencies()
|
1081
1080
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1082
1081
|
transform_kwargs = dict(
|
1083
1082
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.tree".replace("sklearn."
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class DecisionTreeRegressor(BaseTransformer):
|
70
64
|
r"""A decision tree regressor
|
71
65
|
For more details on this class, see [sklearn.tree.DecisionTreeRegressor]
|
@@ -373,20 +367,17 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
373
367
|
self,
|
374
368
|
dataset: DataFrame,
|
375
369
|
inference_method: str,
|
376
|
-
) ->
|
377
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
378
|
-
return the available package that exists in the snowflake anaconda channel
|
370
|
+
) -> None:
|
371
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
379
372
|
|
380
373
|
Args:
|
381
374
|
dataset: snowpark dataframe
|
382
375
|
inference_method: the inference method such as predict, score...
|
383
|
-
|
376
|
+
|
384
377
|
Raises:
|
385
378
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
386
379
|
SnowflakeMLException: If the session is None, raise error
|
387
380
|
|
388
|
-
Returns:
|
389
|
-
A list of available package that exists in the snowflake anaconda channel
|
390
381
|
"""
|
391
382
|
if not self._is_fitted:
|
392
383
|
raise exceptions.SnowflakeMLException(
|
@@ -404,9 +395,7 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
404
395
|
"Session must not specified for snowpark dataset."
|
405
396
|
),
|
406
397
|
)
|
407
|
-
|
408
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
409
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
398
|
+
|
410
399
|
|
411
400
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
412
401
|
@telemetry.send_api_usage_telemetry(
|
@@ -454,7 +443,8 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
454
443
|
|
455
444
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
456
445
|
|
457
|
-
self.
|
446
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
447
|
+
self._deps = self._get_dependencies()
|
458
448
|
assert isinstance(
|
459
449
|
dataset._session, Session
|
460
450
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -537,10 +527,8 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
537
527
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
538
528
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
539
529
|
|
540
|
-
self.
|
541
|
-
|
542
|
-
inference_method=inference_method,
|
543
|
-
)
|
530
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
531
|
+
self._deps = self._get_dependencies()
|
544
532
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
545
533
|
|
546
534
|
transform_kwargs = dict(
|
@@ -607,16 +595,40 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
607
595
|
self._is_fitted = True
|
608
596
|
return output_result
|
609
597
|
|
598
|
+
|
599
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
600
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
601
|
+
""" Method not supported for this class.
|
610
602
|
|
611
|
-
|
612
|
-
|
613
|
-
|
603
|
+
|
604
|
+
Raises:
|
605
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
606
|
+
|
607
|
+
Args:
|
608
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
609
|
+
Snowpark or Pandas DataFrame.
|
610
|
+
output_cols_prefix: Prefix for the response columns
|
614
611
|
Returns:
|
615
612
|
Transformed dataset.
|
616
613
|
"""
|
617
|
-
self.
|
618
|
-
|
619
|
-
|
614
|
+
self._infer_input_output_cols(dataset)
|
615
|
+
super()._check_dataset_type(dataset)
|
616
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
617
|
+
estimator=self._sklearn_object,
|
618
|
+
dataset=dataset,
|
619
|
+
input_cols=self.input_cols,
|
620
|
+
label_cols=self.label_cols,
|
621
|
+
sample_weight_col=self.sample_weight_col,
|
622
|
+
autogenerated=self._autogenerated,
|
623
|
+
subproject=_SUBPROJECT,
|
624
|
+
)
|
625
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
626
|
+
drop_input_cols=self._drop_input_cols,
|
627
|
+
expected_output_cols_list=self.output_cols,
|
628
|
+
)
|
629
|
+
self._sklearn_object = fitted_estimator
|
630
|
+
self._is_fitted = True
|
631
|
+
return output_result
|
620
632
|
|
621
633
|
|
622
634
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -707,10 +719,8 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
707
719
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
708
720
|
|
709
721
|
if isinstance(dataset, DataFrame):
|
710
|
-
self.
|
711
|
-
|
712
|
-
inference_method=inference_method,
|
713
|
-
)
|
722
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
723
|
+
self._deps = self._get_dependencies()
|
714
724
|
assert isinstance(
|
715
725
|
dataset._session, Session
|
716
726
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -775,10 +785,8 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
775
785
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
776
786
|
|
777
787
|
if isinstance(dataset, DataFrame):
|
778
|
-
self.
|
779
|
-
|
780
|
-
inference_method=inference_method,
|
781
|
-
)
|
788
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
789
|
+
self._deps = self._get_dependencies()
|
782
790
|
assert isinstance(
|
783
791
|
dataset._session, Session
|
784
792
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -840,10 +848,8 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
840
848
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
841
849
|
|
842
850
|
if isinstance(dataset, DataFrame):
|
843
|
-
self.
|
844
|
-
|
845
|
-
inference_method=inference_method,
|
846
|
-
)
|
851
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
852
|
+
self._deps = self._get_dependencies()
|
847
853
|
assert isinstance(
|
848
854
|
dataset._session, Session
|
849
855
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -909,10 +915,8 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
909
915
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
910
916
|
|
911
917
|
if isinstance(dataset, DataFrame):
|
912
|
-
self.
|
913
|
-
|
914
|
-
inference_method=inference_method,
|
915
|
-
)
|
918
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
919
|
+
self._deps = self._get_dependencies()
|
916
920
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
917
921
|
transform_kwargs = dict(
|
918
922
|
session=dataset._session,
|
@@ -976,17 +980,15 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
976
980
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
977
981
|
|
978
982
|
if isinstance(dataset, DataFrame):
|
979
|
-
self.
|
980
|
-
|
981
|
-
inference_method="score",
|
982
|
-
)
|
983
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
984
|
+
self._deps = self._get_dependencies()
|
983
985
|
selected_cols = self._get_active_columns()
|
984
986
|
if len(selected_cols) > 0:
|
985
987
|
dataset = dataset.select(selected_cols)
|
986
988
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
987
989
|
transform_kwargs = dict(
|
988
990
|
session=dataset._session,
|
989
|
-
dependencies=
|
991
|
+
dependencies=self._deps,
|
990
992
|
score_sproc_imports=['sklearn'],
|
991
993
|
)
|
992
994
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1051,11 +1053,8 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
1051
1053
|
|
1052
1054
|
if isinstance(dataset, DataFrame):
|
1053
1055
|
|
1054
|
-
self.
|
1055
|
-
|
1056
|
-
inference_method=inference_method,
|
1057
|
-
|
1058
|
-
)
|
1056
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1057
|
+
self._deps = self._get_dependencies()
|
1059
1058
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1060
1059
|
transform_kwargs = dict(
|
1061
1060
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.tree".replace("sklearn."
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class ExtraTreeClassifier(BaseTransformer):
|
70
64
|
r"""An extremely randomized tree classifier
|
71
65
|
For more details on this class, see [sklearn.tree.ExtraTreeClassifier]
|
@@ -383,20 +377,17 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
383
377
|
self,
|
384
378
|
dataset: DataFrame,
|
385
379
|
inference_method: str,
|
386
|
-
) ->
|
387
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
388
|
-
return the available package that exists in the snowflake anaconda channel
|
380
|
+
) -> None:
|
381
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
389
382
|
|
390
383
|
Args:
|
391
384
|
dataset: snowpark dataframe
|
392
385
|
inference_method: the inference method such as predict, score...
|
393
|
-
|
386
|
+
|
394
387
|
Raises:
|
395
388
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
396
389
|
SnowflakeMLException: If the session is None, raise error
|
397
390
|
|
398
|
-
Returns:
|
399
|
-
A list of available package that exists in the snowflake anaconda channel
|
400
391
|
"""
|
401
392
|
if not self._is_fitted:
|
402
393
|
raise exceptions.SnowflakeMLException(
|
@@ -414,9 +405,7 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
414
405
|
"Session must not specified for snowpark dataset."
|
415
406
|
),
|
416
407
|
)
|
417
|
-
|
418
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
419
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
408
|
+
|
420
409
|
|
421
410
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
422
411
|
@telemetry.send_api_usage_telemetry(
|
@@ -464,7 +453,8 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
464
453
|
|
465
454
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
466
455
|
|
467
|
-
self.
|
456
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
457
|
+
self._deps = self._get_dependencies()
|
468
458
|
assert isinstance(
|
469
459
|
dataset._session, Session
|
470
460
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -547,10 +537,8 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
547
537
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
548
538
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
549
539
|
|
550
|
-
self.
|
551
|
-
|
552
|
-
inference_method=inference_method,
|
553
|
-
)
|
540
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
541
|
+
self._deps = self._get_dependencies()
|
554
542
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
555
543
|
|
556
544
|
transform_kwargs = dict(
|
@@ -617,16 +605,40 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
617
605
|
self._is_fitted = True
|
618
606
|
return output_result
|
619
607
|
|
608
|
+
|
609
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
610
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
611
|
+
""" Method not supported for this class.
|
620
612
|
|
621
|
-
|
622
|
-
|
623
|
-
|
613
|
+
|
614
|
+
Raises:
|
615
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
616
|
+
|
617
|
+
Args:
|
618
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
619
|
+
Snowpark or Pandas DataFrame.
|
620
|
+
output_cols_prefix: Prefix for the response columns
|
624
621
|
Returns:
|
625
622
|
Transformed dataset.
|
626
623
|
"""
|
627
|
-
self.
|
628
|
-
|
629
|
-
|
624
|
+
self._infer_input_output_cols(dataset)
|
625
|
+
super()._check_dataset_type(dataset)
|
626
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
627
|
+
estimator=self._sklearn_object,
|
628
|
+
dataset=dataset,
|
629
|
+
input_cols=self.input_cols,
|
630
|
+
label_cols=self.label_cols,
|
631
|
+
sample_weight_col=self.sample_weight_col,
|
632
|
+
autogenerated=self._autogenerated,
|
633
|
+
subproject=_SUBPROJECT,
|
634
|
+
)
|
635
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
636
|
+
drop_input_cols=self._drop_input_cols,
|
637
|
+
expected_output_cols_list=self.output_cols,
|
638
|
+
)
|
639
|
+
self._sklearn_object = fitted_estimator
|
640
|
+
self._is_fitted = True
|
641
|
+
return output_result
|
630
642
|
|
631
643
|
|
632
644
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -719,10 +731,8 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
719
731
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
720
732
|
|
721
733
|
if isinstance(dataset, DataFrame):
|
722
|
-
self.
|
723
|
-
|
724
|
-
inference_method=inference_method,
|
725
|
-
)
|
734
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
735
|
+
self._deps = self._get_dependencies()
|
726
736
|
assert isinstance(
|
727
737
|
dataset._session, Session
|
728
738
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -789,10 +799,8 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
789
799
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
790
800
|
|
791
801
|
if isinstance(dataset, DataFrame):
|
792
|
-
self.
|
793
|
-
|
794
|
-
inference_method=inference_method,
|
795
|
-
)
|
802
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
803
|
+
self._deps = self._get_dependencies()
|
796
804
|
assert isinstance(
|
797
805
|
dataset._session, Session
|
798
806
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -854,10 +862,8 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
854
862
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
855
863
|
|
856
864
|
if isinstance(dataset, DataFrame):
|
857
|
-
self.
|
858
|
-
|
859
|
-
inference_method=inference_method,
|
860
|
-
)
|
865
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
866
|
+
self._deps = self._get_dependencies()
|
861
867
|
assert isinstance(
|
862
868
|
dataset._session, Session
|
863
869
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -923,10 +929,8 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
923
929
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
924
930
|
|
925
931
|
if isinstance(dataset, DataFrame):
|
926
|
-
self.
|
927
|
-
|
928
|
-
inference_method=inference_method,
|
929
|
-
)
|
932
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
933
|
+
self._deps = self._get_dependencies()
|
930
934
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
931
935
|
transform_kwargs = dict(
|
932
936
|
session=dataset._session,
|
@@ -990,17 +994,15 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
990
994
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
991
995
|
|
992
996
|
if isinstance(dataset, DataFrame):
|
993
|
-
self.
|
994
|
-
|
995
|
-
inference_method="score",
|
996
|
-
)
|
997
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
998
|
+
self._deps = self._get_dependencies()
|
997
999
|
selected_cols = self._get_active_columns()
|
998
1000
|
if len(selected_cols) > 0:
|
999
1001
|
dataset = dataset.select(selected_cols)
|
1000
1002
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
1001
1003
|
transform_kwargs = dict(
|
1002
1004
|
session=dataset._session,
|
1003
|
-
dependencies=
|
1005
|
+
dependencies=self._deps,
|
1004
1006
|
score_sproc_imports=['sklearn'],
|
1005
1007
|
)
|
1006
1008
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1065,11 +1067,8 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
1065
1067
|
|
1066
1068
|
if isinstance(dataset, DataFrame):
|
1067
1069
|
|
1068
|
-
self.
|
1069
|
-
|
1070
|
-
inference_method=inference_method,
|
1071
|
-
|
1072
|
-
)
|
1070
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1071
|
+
self._deps = self._get_dependencies()
|
1073
1072
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1074
1073
|
transform_kwargs = dict(
|
1075
1074
|
session = dataset._session,
|