snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
|
|
33
33
|
BatchInferenceKwargsTypedDict,
|
34
34
|
ScoreKwargsTypedDict
|
35
35
|
)
|
36
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
37
|
+
from snowflake.ml.model.model_signature import (
|
38
|
+
BaseFeatureSpec,
|
39
|
+
DataType,
|
40
|
+
FeatureSpec,
|
41
|
+
ModelSignature,
|
42
|
+
_infer_signature,
|
43
|
+
_rename_signature_with_snowflake_identifiers,
|
44
|
+
)
|
36
45
|
|
37
46
|
from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
|
38
47
|
|
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
|
|
43
52
|
validate_sklearn_args,
|
44
53
|
)
|
45
54
|
|
46
|
-
from snowflake.ml.model.model_signature import (
|
47
|
-
DataType,
|
48
|
-
FeatureSpec,
|
49
|
-
ModelSignature,
|
50
|
-
_infer_signature,
|
51
|
-
_rename_signature_with_snowflake_identifiers,
|
52
|
-
BaseFeatureSpec,
|
53
|
-
)
|
54
|
-
from snowflake.ml.model._signatures import utils as model_signature_utils
|
55
|
-
|
56
55
|
_PROJECT = "ModelDevelopment"
|
57
56
|
# Derive subproject from module name by removing "sklearn"
|
58
57
|
# and converting module name from underscore to CamelCase
|
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
|
|
61
60
|
|
62
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
63
62
|
|
64
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
65
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
66
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
67
|
-
return check
|
68
|
-
|
69
|
-
|
70
63
|
class ExtraTreesClassifier(BaseTransformer):
|
71
64
|
r"""An extra-trees classifier
|
72
65
|
For more details on this class, see [sklearn.ensemble.ExtraTreesClassifier]
|
@@ -379,12 +372,7 @@ class ExtraTreesClassifier(BaseTransformer):
|
|
379
372
|
)
|
380
373
|
return selected_cols
|
381
374
|
|
382
|
-
|
383
|
-
project=_PROJECT,
|
384
|
-
subproject=_SUBPROJECT,
|
385
|
-
custom_tags=dict([("autogen", True)]),
|
386
|
-
)
|
387
|
-
def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "ExtraTreesClassifier":
|
375
|
+
def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "ExtraTreesClassifier":
|
388
376
|
"""Build a forest of trees from the training set (X, y)
|
389
377
|
For more details on this function, see [sklearn.ensemble.ExtraTreesClassifier.fit]
|
390
378
|
(https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.ExtraTreesClassifier.html#sklearn.ensemble.ExtraTreesClassifier.fit)
|
@@ -411,12 +399,14 @@ class ExtraTreesClassifier(BaseTransformer):
|
|
411
399
|
|
412
400
|
self._snowpark_cols = dataset.select(self.input_cols).columns
|
413
401
|
|
414
|
-
|
402
|
+
# If we are already in a stored procedure, no need to kick off another one.
|
415
403
|
if SNOWML_SPROC_ENV in os.environ:
|
416
404
|
statement_params = telemetry.get_function_usage_statement_params(
|
417
405
|
project=_PROJECT,
|
418
406
|
subproject=_SUBPROJECT,
|
419
|
-
function_name=telemetry.get_statement_params_full_func_name(
|
407
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
408
|
+
inspect.currentframe(), ExtraTreesClassifier.__class__.__name__
|
409
|
+
),
|
420
410
|
api_calls=[Session.call],
|
421
411
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
422
412
|
)
|
@@ -437,27 +427,24 @@ class ExtraTreesClassifier(BaseTransformer):
|
|
437
427
|
)
|
438
428
|
self._sklearn_object = model_trainer.train()
|
439
429
|
self._is_fitted = True
|
440
|
-
self.
|
430
|
+
self._generate_model_signatures(dataset)
|
441
431
|
return self
|
442
432
|
|
443
433
|
def _batch_inference_validate_snowpark(
|
444
434
|
self,
|
445
435
|
dataset: DataFrame,
|
446
436
|
inference_method: str,
|
447
|
-
) ->
|
448
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
449
|
-
return the available package that exists in the snowflake anaconda channel
|
437
|
+
) -> None:
|
438
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
450
439
|
|
451
440
|
Args:
|
452
441
|
dataset: snowpark dataframe
|
453
442
|
inference_method: the inference method such as predict, score...
|
454
|
-
|
443
|
+
|
455
444
|
Raises:
|
456
445
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
457
446
|
SnowflakeMLException: If the session is None, raise error
|
458
447
|
|
459
|
-
Returns:
|
460
|
-
A list of available package that exists in the snowflake anaconda channel
|
461
448
|
"""
|
462
449
|
if not self._is_fitted:
|
463
450
|
raise exceptions.SnowflakeMLException(
|
@@ -475,9 +462,7 @@ class ExtraTreesClassifier(BaseTransformer):
|
|
475
462
|
"Session must not specified for snowpark dataset."
|
476
463
|
),
|
477
464
|
)
|
478
|
-
|
479
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
480
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
465
|
+
|
481
466
|
|
482
467
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
483
468
|
@telemetry.send_api_usage_telemetry(
|
@@ -513,7 +498,9 @@ class ExtraTreesClassifier(BaseTransformer):
|
|
513
498
|
# when it is classifier, infer the datatype from label columns
|
514
499
|
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
515
500
|
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
516
|
-
label_cols_signatures = [
|
501
|
+
label_cols_signatures = [
|
502
|
+
row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
|
503
|
+
]
|
517
504
|
if len(label_cols_signatures) == 0:
|
518
505
|
error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
|
519
506
|
raise exceptions.SnowflakeMLException(
|
@@ -521,25 +508,23 @@ class ExtraTreesClassifier(BaseTransformer):
|
|
521
508
|
original_exception=ValueError(error_str),
|
522
509
|
)
|
523
510
|
|
524
|
-
expected_type_inferred = convert_sp_to_sf_type(
|
525
|
-
label_cols_signatures[0].as_snowpark_type()
|
526
|
-
)
|
511
|
+
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
527
512
|
|
528
|
-
self.
|
529
|
-
|
513
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
514
|
+
self._deps = self._get_dependencies()
|
515
|
+
assert isinstance(
|
516
|
+
dataset._session, Session
|
517
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
530
518
|
|
531
519
|
transform_kwargs = dict(
|
532
|
-
session
|
533
|
-
dependencies
|
534
|
-
drop_input_cols
|
535
|
-
expected_output_cols_type
|
520
|
+
session=dataset._session,
|
521
|
+
dependencies=self._deps,
|
522
|
+
drop_input_cols=self._drop_input_cols,
|
523
|
+
expected_output_cols_type=expected_type_inferred,
|
536
524
|
)
|
537
525
|
|
538
526
|
elif isinstance(dataset, pd.DataFrame):
|
539
|
-
transform_kwargs = dict(
|
540
|
-
snowpark_input_cols = self._snowpark_cols,
|
541
|
-
drop_input_cols = self._drop_input_cols
|
542
|
-
)
|
527
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
543
528
|
|
544
529
|
transform_handlers = ModelTransformerBuilder.build(
|
545
530
|
dataset=dataset,
|
@@ -579,7 +564,7 @@ class ExtraTreesClassifier(BaseTransformer):
|
|
579
564
|
Transformed dataset.
|
580
565
|
"""
|
581
566
|
super()._check_dataset_type(dataset)
|
582
|
-
inference_method="transform"
|
567
|
+
inference_method = "transform"
|
583
568
|
|
584
569
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
585
570
|
# are specific to the type of dataset used.
|
@@ -609,24 +594,19 @@ class ExtraTreesClassifier(BaseTransformer):
|
|
609
594
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
610
595
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
611
596
|
|
612
|
-
self.
|
613
|
-
|
614
|
-
inference_method=inference_method,
|
615
|
-
)
|
597
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
598
|
+
self._deps = self._get_dependencies()
|
616
599
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
617
600
|
|
618
601
|
transform_kwargs = dict(
|
619
|
-
session
|
620
|
-
dependencies
|
621
|
-
drop_input_cols
|
622
|
-
expected_output_cols_type
|
602
|
+
session=dataset._session,
|
603
|
+
dependencies=self._deps,
|
604
|
+
drop_input_cols=self._drop_input_cols,
|
605
|
+
expected_output_cols_type=expected_dtype,
|
623
606
|
)
|
624
607
|
|
625
608
|
elif isinstance(dataset, pd.DataFrame):
|
626
|
-
transform_kwargs = dict(
|
627
|
-
snowpark_input_cols = self._snowpark_cols,
|
628
|
-
drop_input_cols = self._drop_input_cols
|
629
|
-
)
|
609
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
630
610
|
|
631
611
|
transform_handlers = ModelTransformerBuilder.build(
|
632
612
|
dataset=dataset,
|
@@ -645,7 +625,11 @@ class ExtraTreesClassifier(BaseTransformer):
|
|
645
625
|
return output_df
|
646
626
|
|
647
627
|
@available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
|
648
|
-
def fit_predict(
|
628
|
+
def fit_predict(
|
629
|
+
self,
|
630
|
+
dataset: Union[DataFrame, pd.DataFrame],
|
631
|
+
output_cols_prefix: str = "fit_predict_",
|
632
|
+
) -> Union[DataFrame, pd.DataFrame]:
|
649
633
|
""" Method not supported for this class.
|
650
634
|
|
651
635
|
|
@@ -670,22 +654,104 @@ class ExtraTreesClassifier(BaseTransformer):
|
|
670
654
|
)
|
671
655
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
672
656
|
drop_input_cols=self._drop_input_cols,
|
673
|
-
expected_output_cols_list=
|
657
|
+
expected_output_cols_list=(
|
658
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
659
|
+
),
|
674
660
|
)
|
675
661
|
self._sklearn_object = fitted_estimator
|
676
662
|
self._is_fitted = True
|
677
663
|
return output_result
|
678
664
|
|
665
|
+
|
666
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
667
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
668
|
+
""" Method not supported for this class.
|
669
|
+
|
679
670
|
|
680
|
-
|
681
|
-
|
682
|
-
|
671
|
+
Raises:
|
672
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
673
|
+
|
674
|
+
Args:
|
675
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
676
|
+
Snowpark or Pandas DataFrame.
|
677
|
+
output_cols_prefix: Prefix for the response columns
|
683
678
|
Returns:
|
684
679
|
Transformed dataset.
|
685
680
|
"""
|
686
|
-
self.
|
687
|
-
|
688
|
-
|
681
|
+
self._infer_input_output_cols(dataset)
|
682
|
+
super()._check_dataset_type(dataset)
|
683
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
684
|
+
estimator=self._sklearn_object,
|
685
|
+
dataset=dataset,
|
686
|
+
input_cols=self.input_cols,
|
687
|
+
label_cols=self.label_cols,
|
688
|
+
sample_weight_col=self.sample_weight_col,
|
689
|
+
autogenerated=self._autogenerated,
|
690
|
+
subproject=_SUBPROJECT,
|
691
|
+
)
|
692
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
693
|
+
drop_input_cols=self._drop_input_cols,
|
694
|
+
expected_output_cols_list=self.output_cols,
|
695
|
+
)
|
696
|
+
self._sklearn_object = fitted_estimator
|
697
|
+
self._is_fitted = True
|
698
|
+
return output_result
|
699
|
+
|
700
|
+
|
701
|
+
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
702
|
+
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
703
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
704
|
+
"""
|
705
|
+
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
706
|
+
# The following condition is introduced for kneighbors methods, and not used in other methods
|
707
|
+
if output_cols:
|
708
|
+
output_cols = [
|
709
|
+
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
710
|
+
for c in output_cols
|
711
|
+
]
|
712
|
+
elif getattr(self._sklearn_object, "classes_", None) is None:
|
713
|
+
output_cols = [output_cols_prefix]
|
714
|
+
elif self._sklearn_object is not None:
|
715
|
+
classes = self._sklearn_object.classes_
|
716
|
+
if isinstance(classes, numpy.ndarray):
|
717
|
+
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
718
|
+
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
719
|
+
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
720
|
+
output_cols = []
|
721
|
+
for i, cl in enumerate(classes):
|
722
|
+
# For binary classification, there is only one output column for each class
|
723
|
+
# ndarray as the two classes are complementary.
|
724
|
+
if len(cl) == 2:
|
725
|
+
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
726
|
+
else:
|
727
|
+
output_cols.extend([
|
728
|
+
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
729
|
+
])
|
730
|
+
else:
|
731
|
+
output_cols = []
|
732
|
+
|
733
|
+
# Make sure column names are valid snowflake identifiers.
|
734
|
+
assert output_cols is not None # Make MyPy happy
|
735
|
+
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
736
|
+
|
737
|
+
return rv
|
738
|
+
|
739
|
+
def _align_expected_output_names(
|
740
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
741
|
+
) -> List[str]:
|
742
|
+
# in case the inferred output column names dimension is different
|
743
|
+
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
744
|
+
output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
|
745
|
+
output_df_columns = list(output_df_pd.columns)
|
746
|
+
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
747
|
+
if self.sample_weight_col:
|
748
|
+
output_df_columns_set -= set(self.sample_weight_col)
|
749
|
+
# if the dimension of inferred output column names is correct; use it
|
750
|
+
if len(expected_output_cols_list) == len(output_df_columns_set):
|
751
|
+
return expected_output_cols_list
|
752
|
+
# otherwise, use the sklearn estimator's output
|
753
|
+
else:
|
754
|
+
return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
689
755
|
|
690
756
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
691
757
|
@telemetry.send_api_usage_telemetry(
|
@@ -719,24 +785,26 @@ class ExtraTreesClassifier(BaseTransformer):
|
|
719
785
|
# are specific to the type of dataset used.
|
720
786
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
721
787
|
|
788
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
789
|
+
|
722
790
|
if isinstance(dataset, DataFrame):
|
723
|
-
self.
|
724
|
-
|
725
|
-
|
726
|
-
|
727
|
-
|
791
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
792
|
+
self._deps = self._get_dependencies()
|
793
|
+
assert isinstance(
|
794
|
+
dataset._session, Session
|
795
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
728
796
|
transform_kwargs = dict(
|
729
797
|
session=dataset._session,
|
730
798
|
dependencies=self._deps,
|
731
|
-
drop_input_cols
|
799
|
+
drop_input_cols=self._drop_input_cols,
|
732
800
|
expected_output_cols_type="float",
|
733
801
|
)
|
802
|
+
expected_output_cols = self._align_expected_output_names(
|
803
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
804
|
+
)
|
734
805
|
|
735
806
|
elif isinstance(dataset, pd.DataFrame):
|
736
|
-
transform_kwargs = dict(
|
737
|
-
snowpark_input_cols = self._snowpark_cols,
|
738
|
-
drop_input_cols = self._drop_input_cols
|
739
|
-
)
|
807
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
740
808
|
|
741
809
|
transform_handlers = ModelTransformerBuilder.build(
|
742
810
|
dataset=dataset,
|
@@ -748,7 +816,7 @@ class ExtraTreesClassifier(BaseTransformer):
|
|
748
816
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
749
817
|
inference_method=inference_method,
|
750
818
|
input_cols=self.input_cols,
|
751
|
-
expected_output_cols=
|
819
|
+
expected_output_cols=expected_output_cols,
|
752
820
|
**transform_kwargs
|
753
821
|
)
|
754
822
|
return output_df
|
@@ -780,29 +848,30 @@ class ExtraTreesClassifier(BaseTransformer):
|
|
780
848
|
Output dataset with log probability of the sample for each class in the model.
|
781
849
|
"""
|
782
850
|
super()._check_dataset_type(dataset)
|
783
|
-
inference_method="predict_log_proba"
|
851
|
+
inference_method = "predict_log_proba"
|
852
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
784
853
|
|
785
854
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
786
855
|
# are specific to the type of dataset used.
|
787
856
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
788
857
|
|
789
858
|
if isinstance(dataset, DataFrame):
|
790
|
-
self.
|
791
|
-
|
792
|
-
|
793
|
-
|
794
|
-
|
859
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
860
|
+
self._deps = self._get_dependencies()
|
861
|
+
assert isinstance(
|
862
|
+
dataset._session, Session
|
863
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
795
864
|
transform_kwargs = dict(
|
796
865
|
session=dataset._session,
|
797
866
|
dependencies=self._deps,
|
798
|
-
drop_input_cols
|
867
|
+
drop_input_cols=self._drop_input_cols,
|
799
868
|
expected_output_cols_type="float",
|
800
869
|
)
|
870
|
+
expected_output_cols = self._align_expected_output_names(
|
871
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
872
|
+
)
|
801
873
|
elif isinstance(dataset, pd.DataFrame):
|
802
|
-
transform_kwargs = dict(
|
803
|
-
snowpark_input_cols = self._snowpark_cols,
|
804
|
-
drop_input_cols = self._drop_input_cols
|
805
|
-
)
|
874
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
806
875
|
|
807
876
|
transform_handlers = ModelTransformerBuilder.build(
|
808
877
|
dataset=dataset,
|
@@ -815,7 +884,7 @@ class ExtraTreesClassifier(BaseTransformer):
|
|
815
884
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
816
885
|
inference_method=inference_method,
|
817
886
|
input_cols=self.input_cols,
|
818
|
-
expected_output_cols=
|
887
|
+
expected_output_cols=expected_output_cols,
|
819
888
|
**transform_kwargs
|
820
889
|
)
|
821
890
|
return output_df
|
@@ -841,30 +910,32 @@ class ExtraTreesClassifier(BaseTransformer):
|
|
841
910
|
Output dataset with results of the decision function for the samples in input dataset.
|
842
911
|
"""
|
843
912
|
super()._check_dataset_type(dataset)
|
844
|
-
inference_method="decision_function"
|
913
|
+
inference_method = "decision_function"
|
845
914
|
|
846
915
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
847
916
|
# are specific to the type of dataset used.
|
848
917
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
849
918
|
|
919
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
920
|
+
|
850
921
|
if isinstance(dataset, DataFrame):
|
851
|
-
self.
|
852
|
-
|
853
|
-
|
854
|
-
|
855
|
-
|
922
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
923
|
+
self._deps = self._get_dependencies()
|
924
|
+
assert isinstance(
|
925
|
+
dataset._session, Session
|
926
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
856
927
|
transform_kwargs = dict(
|
857
928
|
session=dataset._session,
|
858
929
|
dependencies=self._deps,
|
859
|
-
drop_input_cols
|
930
|
+
drop_input_cols=self._drop_input_cols,
|
860
931
|
expected_output_cols_type="float",
|
861
932
|
)
|
933
|
+
expected_output_cols = self._align_expected_output_names(
|
934
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
935
|
+
)
|
862
936
|
|
863
937
|
elif isinstance(dataset, pd.DataFrame):
|
864
|
-
transform_kwargs = dict(
|
865
|
-
snowpark_input_cols = self._snowpark_cols,
|
866
|
-
drop_input_cols = self._drop_input_cols
|
867
|
-
)
|
938
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
868
939
|
|
869
940
|
transform_handlers = ModelTransformerBuilder.build(
|
870
941
|
dataset=dataset,
|
@@ -877,7 +948,7 @@ class ExtraTreesClassifier(BaseTransformer):
|
|
877
948
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
878
949
|
inference_method=inference_method,
|
879
950
|
input_cols=self.input_cols,
|
880
|
-
expected_output_cols=
|
951
|
+
expected_output_cols=expected_output_cols,
|
881
952
|
**transform_kwargs
|
882
953
|
)
|
883
954
|
return output_df
|
@@ -906,17 +977,17 @@ class ExtraTreesClassifier(BaseTransformer):
|
|
906
977
|
Output dataset with probability of the sample for each class in the model.
|
907
978
|
"""
|
908
979
|
super()._check_dataset_type(dataset)
|
909
|
-
inference_method="score_samples"
|
980
|
+
inference_method = "score_samples"
|
910
981
|
|
911
982
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
912
983
|
# are specific to the type of dataset used.
|
913
984
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
914
985
|
|
986
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
987
|
+
|
915
988
|
if isinstance(dataset, DataFrame):
|
916
|
-
self.
|
917
|
-
|
918
|
-
inference_method=inference_method,
|
919
|
-
)
|
989
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
990
|
+
self._deps = self._get_dependencies()
|
920
991
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
921
992
|
transform_kwargs = dict(
|
922
993
|
session=dataset._session,
|
@@ -924,6 +995,9 @@ class ExtraTreesClassifier(BaseTransformer):
|
|
924
995
|
drop_input_cols = self._drop_input_cols,
|
925
996
|
expected_output_cols_type="float",
|
926
997
|
)
|
998
|
+
expected_output_cols = self._align_expected_output_names(
|
999
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
1000
|
+
)
|
927
1001
|
|
928
1002
|
elif isinstance(dataset, pd.DataFrame):
|
929
1003
|
transform_kwargs = dict(
|
@@ -942,7 +1016,7 @@ class ExtraTreesClassifier(BaseTransformer):
|
|
942
1016
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
943
1017
|
inference_method=inference_method,
|
944
1018
|
input_cols=self.input_cols,
|
945
|
-
expected_output_cols=
|
1019
|
+
expected_output_cols=expected_output_cols,
|
946
1020
|
**transform_kwargs
|
947
1021
|
)
|
948
1022
|
return output_df
|
@@ -977,17 +1051,15 @@ class ExtraTreesClassifier(BaseTransformer):
|
|
977
1051
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
978
1052
|
|
979
1053
|
if isinstance(dataset, DataFrame):
|
980
|
-
self.
|
981
|
-
|
982
|
-
inference_method="score",
|
983
|
-
)
|
1054
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
1055
|
+
self._deps = self._get_dependencies()
|
984
1056
|
selected_cols = self._get_active_columns()
|
985
1057
|
if len(selected_cols) > 0:
|
986
1058
|
dataset = dataset.select(selected_cols)
|
987
1059
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
988
1060
|
transform_kwargs = dict(
|
989
1061
|
session=dataset._session,
|
990
|
-
dependencies=
|
1062
|
+
dependencies=self._deps,
|
991
1063
|
score_sproc_imports=['sklearn'],
|
992
1064
|
)
|
993
1065
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1052,11 +1124,8 @@ class ExtraTreesClassifier(BaseTransformer):
|
|
1052
1124
|
|
1053
1125
|
if isinstance(dataset, DataFrame):
|
1054
1126
|
|
1055
|
-
self.
|
1056
|
-
|
1057
|
-
inference_method=inference_method,
|
1058
|
-
|
1059
|
-
)
|
1127
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1128
|
+
self._deps = self._get_dependencies()
|
1060
1129
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1061
1130
|
transform_kwargs = dict(
|
1062
1131
|
session = dataset._session,
|
@@ -1089,50 +1158,84 @@ class ExtraTreesClassifier(BaseTransformer):
|
|
1089
1158
|
)
|
1090
1159
|
return output_df
|
1091
1160
|
|
1161
|
+
|
1162
|
+
|
1163
|
+
def to_sklearn(self) -> Any:
|
1164
|
+
"""Get sklearn.ensemble.ExtraTreesClassifier object.
|
1165
|
+
"""
|
1166
|
+
if self._sklearn_object is None:
|
1167
|
+
self._sklearn_object = self._create_sklearn_object()
|
1168
|
+
return self._sklearn_object
|
1169
|
+
|
1170
|
+
def to_xgboost(self) -> Any:
|
1171
|
+
raise exceptions.SnowflakeMLException(
|
1172
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1173
|
+
original_exception=AttributeError(
|
1174
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1175
|
+
"to_xgboost()",
|
1176
|
+
"to_sklearn()"
|
1177
|
+
)
|
1178
|
+
),
|
1179
|
+
)
|
1180
|
+
|
1181
|
+
def to_lightgbm(self) -> Any:
|
1182
|
+
raise exceptions.SnowflakeMLException(
|
1183
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1184
|
+
original_exception=AttributeError(
|
1185
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1186
|
+
"to_lightgbm()",
|
1187
|
+
"to_sklearn()"
|
1188
|
+
)
|
1189
|
+
),
|
1190
|
+
)
|
1191
|
+
|
1192
|
+
def _get_dependencies(self) -> List[str]:
|
1193
|
+
return self._deps
|
1194
|
+
|
1092
1195
|
|
1093
|
-
def
|
1196
|
+
def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
1094
1197
|
self._model_signature_dict = dict()
|
1095
1198
|
|
1096
1199
|
PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
|
1097
1200
|
|
1098
|
-
inputs = list(_infer_signature(dataset[self.input_cols], "input"))
|
1201
|
+
inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
|
1099
1202
|
outputs: List[BaseFeatureSpec] = []
|
1100
1203
|
if hasattr(self, "predict"):
|
1101
1204
|
# keep mypy happy
|
1102
|
-
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1205
|
+
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1103
1206
|
# For classifier, the type of predict is the same as the type of label
|
1104
|
-
if self._sklearn_object._estimator_type ==
|
1105
|
-
|
1207
|
+
if self._sklearn_object._estimator_type == "classifier":
|
1208
|
+
# label columns is the desired type for output
|
1106
1209
|
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
1107
1210
|
# rename the output columns
|
1108
1211
|
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
1109
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1110
|
-
|
1111
|
-
|
1212
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1213
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1214
|
+
)
|
1112
1215
|
# For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
|
1113
1216
|
# For outlier models, returns -1 for outliers and 1 for inliers.
|
1114
|
-
# Clusterer returns int64 cluster labels.
|
1217
|
+
# Clusterer returns int64 cluster labels.
|
1115
1218
|
elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
|
1116
1219
|
outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
|
1117
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1118
|
-
|
1119
|
-
|
1120
|
-
|
1220
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1221
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1222
|
+
)
|
1223
|
+
|
1121
1224
|
# For regressor, the type of predict is float64
|
1122
|
-
elif self._sklearn_object._estimator_type ==
|
1225
|
+
elif self._sklearn_object._estimator_type == "regressor":
|
1123
1226
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
1124
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1125
|
-
|
1126
|
-
|
1127
|
-
|
1227
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1228
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1229
|
+
)
|
1230
|
+
|
1128
1231
|
for prob_func in PROB_FUNCTIONS:
|
1129
1232
|
if hasattr(self, prob_func):
|
1130
1233
|
output_cols_prefix: str = f"{prob_func}_"
|
1131
1234
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
1132
1235
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
1133
|
-
self._model_signature_dict[prob_func] = ModelSignature(
|
1134
|
-
|
1135
|
-
|
1236
|
+
self._model_signature_dict[prob_func] = ModelSignature(
|
1237
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1238
|
+
)
|
1136
1239
|
|
1137
1240
|
# Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
|
1138
1241
|
items = list(self._model_signature_dict.items())
|
@@ -1145,10 +1248,10 @@ class ExtraTreesClassifier(BaseTransformer):
|
|
1145
1248
|
"""Returns model signature of current class.
|
1146
1249
|
|
1147
1250
|
Raises:
|
1148
|
-
|
1251
|
+
SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
|
1149
1252
|
|
1150
1253
|
Returns:
|
1151
|
-
Dict
|
1254
|
+
Dict with each method and its input output signature
|
1152
1255
|
"""
|
1153
1256
|
if self._model_signature_dict is None:
|
1154
1257
|
raise exceptions.SnowflakeMLException(
|
@@ -1156,35 +1259,3 @@ class ExtraTreesClassifier(BaseTransformer):
|
|
1156
1259
|
original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
|
1157
1260
|
)
|
1158
1261
|
return self._model_signature_dict
|
1159
|
-
|
1160
|
-
def to_sklearn(self) -> Any:
|
1161
|
-
"""Get sklearn.ensemble.ExtraTreesClassifier object.
|
1162
|
-
"""
|
1163
|
-
if self._sklearn_object is None:
|
1164
|
-
self._sklearn_object = self._create_sklearn_object()
|
1165
|
-
return self._sklearn_object
|
1166
|
-
|
1167
|
-
def to_xgboost(self) -> Any:
|
1168
|
-
raise exceptions.SnowflakeMLException(
|
1169
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1170
|
-
original_exception=AttributeError(
|
1171
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1172
|
-
"to_xgboost()",
|
1173
|
-
"to_sklearn()"
|
1174
|
-
)
|
1175
|
-
),
|
1176
|
-
)
|
1177
|
-
|
1178
|
-
def to_lightgbm(self) -> Any:
|
1179
|
-
raise exceptions.SnowflakeMLException(
|
1180
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1181
|
-
original_exception=AttributeError(
|
1182
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1183
|
-
"to_lightgbm()",
|
1184
|
-
"to_sklearn()"
|
1185
|
-
)
|
1186
|
-
),
|
1187
|
-
)
|
1188
|
-
|
1189
|
-
def _get_dependencies(self) -> List[str]:
|
1190
|
-
return self._deps
|