snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
|
|
33
33
|
BatchInferenceKwargsTypedDict,
|
34
34
|
ScoreKwargsTypedDict
|
35
35
|
)
|
36
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
37
|
+
from snowflake.ml.model.model_signature import (
|
38
|
+
BaseFeatureSpec,
|
39
|
+
DataType,
|
40
|
+
FeatureSpec,
|
41
|
+
ModelSignature,
|
42
|
+
_infer_signature,
|
43
|
+
_rename_signature_with_snowflake_identifiers,
|
44
|
+
)
|
36
45
|
|
37
46
|
from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
|
38
47
|
|
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
|
|
43
52
|
validate_sklearn_args,
|
44
53
|
)
|
45
54
|
|
46
|
-
from snowflake.ml.model.model_signature import (
|
47
|
-
DataType,
|
48
|
-
FeatureSpec,
|
49
|
-
ModelSignature,
|
50
|
-
_infer_signature,
|
51
|
-
_rename_signature_with_snowflake_identifiers,
|
52
|
-
BaseFeatureSpec,
|
53
|
-
)
|
54
|
-
from snowflake.ml.model._signatures import utils as model_signature_utils
|
55
|
-
|
56
55
|
_PROJECT = "ModelDevelopment"
|
57
56
|
# Derive subproject from module name by removing "sklearn"
|
58
57
|
# and converting module name from underscore to CamelCase
|
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
|
|
61
60
|
|
62
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
63
62
|
|
64
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
65
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
66
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
67
|
-
return check
|
68
|
-
|
69
|
-
|
70
63
|
class ExtraTreesRegressor(BaseTransformer):
|
71
64
|
r"""An extra-trees regressor
|
72
65
|
For more details on this class, see [sklearn.ensemble.ExtraTreesRegressor]
|
@@ -358,12 +351,7 @@ class ExtraTreesRegressor(BaseTransformer):
|
|
358
351
|
)
|
359
352
|
return selected_cols
|
360
353
|
|
361
|
-
|
362
|
-
project=_PROJECT,
|
363
|
-
subproject=_SUBPROJECT,
|
364
|
-
custom_tags=dict([("autogen", True)]),
|
365
|
-
)
|
366
|
-
def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "ExtraTreesRegressor":
|
354
|
+
def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "ExtraTreesRegressor":
|
367
355
|
"""Build a forest of trees from the training set (X, y)
|
368
356
|
For more details on this function, see [sklearn.ensemble.ExtraTreesRegressor.fit]
|
369
357
|
(https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.ExtraTreesRegressor.html#sklearn.ensemble.ExtraTreesRegressor.fit)
|
@@ -390,12 +378,14 @@ class ExtraTreesRegressor(BaseTransformer):
|
|
390
378
|
|
391
379
|
self._snowpark_cols = dataset.select(self.input_cols).columns
|
392
380
|
|
393
|
-
|
381
|
+
# If we are already in a stored procedure, no need to kick off another one.
|
394
382
|
if SNOWML_SPROC_ENV in os.environ:
|
395
383
|
statement_params = telemetry.get_function_usage_statement_params(
|
396
384
|
project=_PROJECT,
|
397
385
|
subproject=_SUBPROJECT,
|
398
|
-
function_name=telemetry.get_statement_params_full_func_name(
|
386
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
387
|
+
inspect.currentframe(), ExtraTreesRegressor.__class__.__name__
|
388
|
+
),
|
399
389
|
api_calls=[Session.call],
|
400
390
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
401
391
|
)
|
@@ -416,27 +406,24 @@ class ExtraTreesRegressor(BaseTransformer):
|
|
416
406
|
)
|
417
407
|
self._sklearn_object = model_trainer.train()
|
418
408
|
self._is_fitted = True
|
419
|
-
self.
|
409
|
+
self._generate_model_signatures(dataset)
|
420
410
|
return self
|
421
411
|
|
422
412
|
def _batch_inference_validate_snowpark(
|
423
413
|
self,
|
424
414
|
dataset: DataFrame,
|
425
415
|
inference_method: str,
|
426
|
-
) ->
|
427
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
428
|
-
return the available package that exists in the snowflake anaconda channel
|
416
|
+
) -> None:
|
417
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
429
418
|
|
430
419
|
Args:
|
431
420
|
dataset: snowpark dataframe
|
432
421
|
inference_method: the inference method such as predict, score...
|
433
|
-
|
422
|
+
|
434
423
|
Raises:
|
435
424
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
436
425
|
SnowflakeMLException: If the session is None, raise error
|
437
426
|
|
438
|
-
Returns:
|
439
|
-
A list of available package that exists in the snowflake anaconda channel
|
440
427
|
"""
|
441
428
|
if not self._is_fitted:
|
442
429
|
raise exceptions.SnowflakeMLException(
|
@@ -454,9 +441,7 @@ class ExtraTreesRegressor(BaseTransformer):
|
|
454
441
|
"Session must not specified for snowpark dataset."
|
455
442
|
),
|
456
443
|
)
|
457
|
-
|
458
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
459
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
444
|
+
|
460
445
|
|
461
446
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
462
447
|
@telemetry.send_api_usage_telemetry(
|
@@ -492,7 +477,9 @@ class ExtraTreesRegressor(BaseTransformer):
|
|
492
477
|
# when it is classifier, infer the datatype from label columns
|
493
478
|
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
494
479
|
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
495
|
-
label_cols_signatures = [
|
480
|
+
label_cols_signatures = [
|
481
|
+
row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
|
482
|
+
]
|
496
483
|
if len(label_cols_signatures) == 0:
|
497
484
|
error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
|
498
485
|
raise exceptions.SnowflakeMLException(
|
@@ -500,25 +487,23 @@ class ExtraTreesRegressor(BaseTransformer):
|
|
500
487
|
original_exception=ValueError(error_str),
|
501
488
|
)
|
502
489
|
|
503
|
-
expected_type_inferred = convert_sp_to_sf_type(
|
504
|
-
label_cols_signatures[0].as_snowpark_type()
|
505
|
-
)
|
490
|
+
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
506
491
|
|
507
|
-
self.
|
508
|
-
|
492
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
493
|
+
self._deps = self._get_dependencies()
|
494
|
+
assert isinstance(
|
495
|
+
dataset._session, Session
|
496
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
509
497
|
|
510
498
|
transform_kwargs = dict(
|
511
|
-
session
|
512
|
-
dependencies
|
513
|
-
drop_input_cols
|
514
|
-
expected_output_cols_type
|
499
|
+
session=dataset._session,
|
500
|
+
dependencies=self._deps,
|
501
|
+
drop_input_cols=self._drop_input_cols,
|
502
|
+
expected_output_cols_type=expected_type_inferred,
|
515
503
|
)
|
516
504
|
|
517
505
|
elif isinstance(dataset, pd.DataFrame):
|
518
|
-
transform_kwargs = dict(
|
519
|
-
snowpark_input_cols = self._snowpark_cols,
|
520
|
-
drop_input_cols = self._drop_input_cols
|
521
|
-
)
|
506
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
522
507
|
|
523
508
|
transform_handlers = ModelTransformerBuilder.build(
|
524
509
|
dataset=dataset,
|
@@ -558,7 +543,7 @@ class ExtraTreesRegressor(BaseTransformer):
|
|
558
543
|
Transformed dataset.
|
559
544
|
"""
|
560
545
|
super()._check_dataset_type(dataset)
|
561
|
-
inference_method="transform"
|
546
|
+
inference_method = "transform"
|
562
547
|
|
563
548
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
564
549
|
# are specific to the type of dataset used.
|
@@ -588,24 +573,19 @@ class ExtraTreesRegressor(BaseTransformer):
|
|
588
573
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
589
574
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
590
575
|
|
591
|
-
self.
|
592
|
-
|
593
|
-
inference_method=inference_method,
|
594
|
-
)
|
576
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
577
|
+
self._deps = self._get_dependencies()
|
595
578
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
596
579
|
|
597
580
|
transform_kwargs = dict(
|
598
|
-
session
|
599
|
-
dependencies
|
600
|
-
drop_input_cols
|
601
|
-
expected_output_cols_type
|
581
|
+
session=dataset._session,
|
582
|
+
dependencies=self._deps,
|
583
|
+
drop_input_cols=self._drop_input_cols,
|
584
|
+
expected_output_cols_type=expected_dtype,
|
602
585
|
)
|
603
586
|
|
604
587
|
elif isinstance(dataset, pd.DataFrame):
|
605
|
-
transform_kwargs = dict(
|
606
|
-
snowpark_input_cols = self._snowpark_cols,
|
607
|
-
drop_input_cols = self._drop_input_cols
|
608
|
-
)
|
588
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
609
589
|
|
610
590
|
transform_handlers = ModelTransformerBuilder.build(
|
611
591
|
dataset=dataset,
|
@@ -624,7 +604,11 @@ class ExtraTreesRegressor(BaseTransformer):
|
|
624
604
|
return output_df
|
625
605
|
|
626
606
|
@available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
|
627
|
-
def fit_predict(
|
607
|
+
def fit_predict(
|
608
|
+
self,
|
609
|
+
dataset: Union[DataFrame, pd.DataFrame],
|
610
|
+
output_cols_prefix: str = "fit_predict_",
|
611
|
+
) -> Union[DataFrame, pd.DataFrame]:
|
628
612
|
""" Method not supported for this class.
|
629
613
|
|
630
614
|
|
@@ -649,22 +633,104 @@ class ExtraTreesRegressor(BaseTransformer):
|
|
649
633
|
)
|
650
634
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
651
635
|
drop_input_cols=self._drop_input_cols,
|
652
|
-
expected_output_cols_list=
|
636
|
+
expected_output_cols_list=(
|
637
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
638
|
+
),
|
653
639
|
)
|
654
640
|
self._sklearn_object = fitted_estimator
|
655
641
|
self._is_fitted = True
|
656
642
|
return output_result
|
657
643
|
|
644
|
+
|
645
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
646
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
647
|
+
""" Method not supported for this class.
|
648
|
+
|
658
649
|
|
659
|
-
|
660
|
-
|
661
|
-
|
650
|
+
Raises:
|
651
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
652
|
+
|
653
|
+
Args:
|
654
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
655
|
+
Snowpark or Pandas DataFrame.
|
656
|
+
output_cols_prefix: Prefix for the response columns
|
662
657
|
Returns:
|
663
658
|
Transformed dataset.
|
664
659
|
"""
|
665
|
-
self.
|
666
|
-
|
667
|
-
|
660
|
+
self._infer_input_output_cols(dataset)
|
661
|
+
super()._check_dataset_type(dataset)
|
662
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
663
|
+
estimator=self._sklearn_object,
|
664
|
+
dataset=dataset,
|
665
|
+
input_cols=self.input_cols,
|
666
|
+
label_cols=self.label_cols,
|
667
|
+
sample_weight_col=self.sample_weight_col,
|
668
|
+
autogenerated=self._autogenerated,
|
669
|
+
subproject=_SUBPROJECT,
|
670
|
+
)
|
671
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
672
|
+
drop_input_cols=self._drop_input_cols,
|
673
|
+
expected_output_cols_list=self.output_cols,
|
674
|
+
)
|
675
|
+
self._sklearn_object = fitted_estimator
|
676
|
+
self._is_fitted = True
|
677
|
+
return output_result
|
678
|
+
|
679
|
+
|
680
|
+
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
681
|
+
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
682
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
683
|
+
"""
|
684
|
+
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
685
|
+
# The following condition is introduced for kneighbors methods, and not used in other methods
|
686
|
+
if output_cols:
|
687
|
+
output_cols = [
|
688
|
+
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
689
|
+
for c in output_cols
|
690
|
+
]
|
691
|
+
elif getattr(self._sklearn_object, "classes_", None) is None:
|
692
|
+
output_cols = [output_cols_prefix]
|
693
|
+
elif self._sklearn_object is not None:
|
694
|
+
classes = self._sklearn_object.classes_
|
695
|
+
if isinstance(classes, numpy.ndarray):
|
696
|
+
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
697
|
+
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
698
|
+
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
699
|
+
output_cols = []
|
700
|
+
for i, cl in enumerate(classes):
|
701
|
+
# For binary classification, there is only one output column for each class
|
702
|
+
# ndarray as the two classes are complementary.
|
703
|
+
if len(cl) == 2:
|
704
|
+
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
705
|
+
else:
|
706
|
+
output_cols.extend([
|
707
|
+
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
708
|
+
])
|
709
|
+
else:
|
710
|
+
output_cols = []
|
711
|
+
|
712
|
+
# Make sure column names are valid snowflake identifiers.
|
713
|
+
assert output_cols is not None # Make MyPy happy
|
714
|
+
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
715
|
+
|
716
|
+
return rv
|
717
|
+
|
718
|
+
def _align_expected_output_names(
|
719
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
720
|
+
) -> List[str]:
|
721
|
+
# in case the inferred output column names dimension is different
|
722
|
+
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
723
|
+
output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
|
724
|
+
output_df_columns = list(output_df_pd.columns)
|
725
|
+
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
726
|
+
if self.sample_weight_col:
|
727
|
+
output_df_columns_set -= set(self.sample_weight_col)
|
728
|
+
# if the dimension of inferred output column names is correct; use it
|
729
|
+
if len(expected_output_cols_list) == len(output_df_columns_set):
|
730
|
+
return expected_output_cols_list
|
731
|
+
# otherwise, use the sklearn estimator's output
|
732
|
+
else:
|
733
|
+
return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
668
734
|
|
669
735
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
670
736
|
@telemetry.send_api_usage_telemetry(
|
@@ -696,24 +762,26 @@ class ExtraTreesRegressor(BaseTransformer):
|
|
696
762
|
# are specific to the type of dataset used.
|
697
763
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
698
764
|
|
765
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
766
|
+
|
699
767
|
if isinstance(dataset, DataFrame):
|
700
|
-
self.
|
701
|
-
|
702
|
-
|
703
|
-
|
704
|
-
|
768
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
769
|
+
self._deps = self._get_dependencies()
|
770
|
+
assert isinstance(
|
771
|
+
dataset._session, Session
|
772
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
705
773
|
transform_kwargs = dict(
|
706
774
|
session=dataset._session,
|
707
775
|
dependencies=self._deps,
|
708
|
-
drop_input_cols
|
776
|
+
drop_input_cols=self._drop_input_cols,
|
709
777
|
expected_output_cols_type="float",
|
710
778
|
)
|
779
|
+
expected_output_cols = self._align_expected_output_names(
|
780
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
781
|
+
)
|
711
782
|
|
712
783
|
elif isinstance(dataset, pd.DataFrame):
|
713
|
-
transform_kwargs = dict(
|
714
|
-
snowpark_input_cols = self._snowpark_cols,
|
715
|
-
drop_input_cols = self._drop_input_cols
|
716
|
-
)
|
784
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
717
785
|
|
718
786
|
transform_handlers = ModelTransformerBuilder.build(
|
719
787
|
dataset=dataset,
|
@@ -725,7 +793,7 @@ class ExtraTreesRegressor(BaseTransformer):
|
|
725
793
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
726
794
|
inference_method=inference_method,
|
727
795
|
input_cols=self.input_cols,
|
728
|
-
expected_output_cols=
|
796
|
+
expected_output_cols=expected_output_cols,
|
729
797
|
**transform_kwargs
|
730
798
|
)
|
731
799
|
return output_df
|
@@ -755,29 +823,30 @@ class ExtraTreesRegressor(BaseTransformer):
|
|
755
823
|
Output dataset with log probability of the sample for each class in the model.
|
756
824
|
"""
|
757
825
|
super()._check_dataset_type(dataset)
|
758
|
-
inference_method="predict_log_proba"
|
826
|
+
inference_method = "predict_log_proba"
|
827
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
759
828
|
|
760
829
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
761
830
|
# are specific to the type of dataset used.
|
762
831
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
763
832
|
|
764
833
|
if isinstance(dataset, DataFrame):
|
765
|
-
self.
|
766
|
-
|
767
|
-
|
768
|
-
|
769
|
-
|
834
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
835
|
+
self._deps = self._get_dependencies()
|
836
|
+
assert isinstance(
|
837
|
+
dataset._session, Session
|
838
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
770
839
|
transform_kwargs = dict(
|
771
840
|
session=dataset._session,
|
772
841
|
dependencies=self._deps,
|
773
|
-
drop_input_cols
|
842
|
+
drop_input_cols=self._drop_input_cols,
|
774
843
|
expected_output_cols_type="float",
|
775
844
|
)
|
845
|
+
expected_output_cols = self._align_expected_output_names(
|
846
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
847
|
+
)
|
776
848
|
elif isinstance(dataset, pd.DataFrame):
|
777
|
-
transform_kwargs = dict(
|
778
|
-
snowpark_input_cols = self._snowpark_cols,
|
779
|
-
drop_input_cols = self._drop_input_cols
|
780
|
-
)
|
849
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
781
850
|
|
782
851
|
transform_handlers = ModelTransformerBuilder.build(
|
783
852
|
dataset=dataset,
|
@@ -790,7 +859,7 @@ class ExtraTreesRegressor(BaseTransformer):
|
|
790
859
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
791
860
|
inference_method=inference_method,
|
792
861
|
input_cols=self.input_cols,
|
793
|
-
expected_output_cols=
|
862
|
+
expected_output_cols=expected_output_cols,
|
794
863
|
**transform_kwargs
|
795
864
|
)
|
796
865
|
return output_df
|
@@ -816,30 +885,32 @@ class ExtraTreesRegressor(BaseTransformer):
|
|
816
885
|
Output dataset with results of the decision function for the samples in input dataset.
|
817
886
|
"""
|
818
887
|
super()._check_dataset_type(dataset)
|
819
|
-
inference_method="decision_function"
|
888
|
+
inference_method = "decision_function"
|
820
889
|
|
821
890
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
822
891
|
# are specific to the type of dataset used.
|
823
892
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
824
893
|
|
894
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
895
|
+
|
825
896
|
if isinstance(dataset, DataFrame):
|
826
|
-
self.
|
827
|
-
|
828
|
-
|
829
|
-
|
830
|
-
|
897
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
898
|
+
self._deps = self._get_dependencies()
|
899
|
+
assert isinstance(
|
900
|
+
dataset._session, Session
|
901
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
831
902
|
transform_kwargs = dict(
|
832
903
|
session=dataset._session,
|
833
904
|
dependencies=self._deps,
|
834
|
-
drop_input_cols
|
905
|
+
drop_input_cols=self._drop_input_cols,
|
835
906
|
expected_output_cols_type="float",
|
836
907
|
)
|
908
|
+
expected_output_cols = self._align_expected_output_names(
|
909
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
910
|
+
)
|
837
911
|
|
838
912
|
elif isinstance(dataset, pd.DataFrame):
|
839
|
-
transform_kwargs = dict(
|
840
|
-
snowpark_input_cols = self._snowpark_cols,
|
841
|
-
drop_input_cols = self._drop_input_cols
|
842
|
-
)
|
913
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
843
914
|
|
844
915
|
transform_handlers = ModelTransformerBuilder.build(
|
845
916
|
dataset=dataset,
|
@@ -852,7 +923,7 @@ class ExtraTreesRegressor(BaseTransformer):
|
|
852
923
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
853
924
|
inference_method=inference_method,
|
854
925
|
input_cols=self.input_cols,
|
855
|
-
expected_output_cols=
|
926
|
+
expected_output_cols=expected_output_cols,
|
856
927
|
**transform_kwargs
|
857
928
|
)
|
858
929
|
return output_df
|
@@ -881,17 +952,17 @@ class ExtraTreesRegressor(BaseTransformer):
|
|
881
952
|
Output dataset with probability of the sample for each class in the model.
|
882
953
|
"""
|
883
954
|
super()._check_dataset_type(dataset)
|
884
|
-
inference_method="score_samples"
|
955
|
+
inference_method = "score_samples"
|
885
956
|
|
886
957
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
887
958
|
# are specific to the type of dataset used.
|
888
959
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
889
960
|
|
961
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
962
|
+
|
890
963
|
if isinstance(dataset, DataFrame):
|
891
|
-
self.
|
892
|
-
|
893
|
-
inference_method=inference_method,
|
894
|
-
)
|
964
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
965
|
+
self._deps = self._get_dependencies()
|
895
966
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
896
967
|
transform_kwargs = dict(
|
897
968
|
session=dataset._session,
|
@@ -899,6 +970,9 @@ class ExtraTreesRegressor(BaseTransformer):
|
|
899
970
|
drop_input_cols = self._drop_input_cols,
|
900
971
|
expected_output_cols_type="float",
|
901
972
|
)
|
973
|
+
expected_output_cols = self._align_expected_output_names(
|
974
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
975
|
+
)
|
902
976
|
|
903
977
|
elif isinstance(dataset, pd.DataFrame):
|
904
978
|
transform_kwargs = dict(
|
@@ -917,7 +991,7 @@ class ExtraTreesRegressor(BaseTransformer):
|
|
917
991
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
918
992
|
inference_method=inference_method,
|
919
993
|
input_cols=self.input_cols,
|
920
|
-
expected_output_cols=
|
994
|
+
expected_output_cols=expected_output_cols,
|
921
995
|
**transform_kwargs
|
922
996
|
)
|
923
997
|
return output_df
|
@@ -952,17 +1026,15 @@ class ExtraTreesRegressor(BaseTransformer):
|
|
952
1026
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
953
1027
|
|
954
1028
|
if isinstance(dataset, DataFrame):
|
955
|
-
self.
|
956
|
-
|
957
|
-
inference_method="score",
|
958
|
-
)
|
1029
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
1030
|
+
self._deps = self._get_dependencies()
|
959
1031
|
selected_cols = self._get_active_columns()
|
960
1032
|
if len(selected_cols) > 0:
|
961
1033
|
dataset = dataset.select(selected_cols)
|
962
1034
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
963
1035
|
transform_kwargs = dict(
|
964
1036
|
session=dataset._session,
|
965
|
-
dependencies=
|
1037
|
+
dependencies=self._deps,
|
966
1038
|
score_sproc_imports=['sklearn'],
|
967
1039
|
)
|
968
1040
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1027,11 +1099,8 @@ class ExtraTreesRegressor(BaseTransformer):
|
|
1027
1099
|
|
1028
1100
|
if isinstance(dataset, DataFrame):
|
1029
1101
|
|
1030
|
-
self.
|
1031
|
-
|
1032
|
-
inference_method=inference_method,
|
1033
|
-
|
1034
|
-
)
|
1102
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1103
|
+
self._deps = self._get_dependencies()
|
1035
1104
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1036
1105
|
transform_kwargs = dict(
|
1037
1106
|
session = dataset._session,
|
@@ -1064,50 +1133,84 @@ class ExtraTreesRegressor(BaseTransformer):
|
|
1064
1133
|
)
|
1065
1134
|
return output_df
|
1066
1135
|
|
1136
|
+
|
1137
|
+
|
1138
|
+
def to_sklearn(self) -> Any:
|
1139
|
+
"""Get sklearn.ensemble.ExtraTreesRegressor object.
|
1140
|
+
"""
|
1141
|
+
if self._sklearn_object is None:
|
1142
|
+
self._sklearn_object = self._create_sklearn_object()
|
1143
|
+
return self._sklearn_object
|
1144
|
+
|
1145
|
+
def to_xgboost(self) -> Any:
|
1146
|
+
raise exceptions.SnowflakeMLException(
|
1147
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1148
|
+
original_exception=AttributeError(
|
1149
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1150
|
+
"to_xgboost()",
|
1151
|
+
"to_sklearn()"
|
1152
|
+
)
|
1153
|
+
),
|
1154
|
+
)
|
1155
|
+
|
1156
|
+
def to_lightgbm(self) -> Any:
|
1157
|
+
raise exceptions.SnowflakeMLException(
|
1158
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1159
|
+
original_exception=AttributeError(
|
1160
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1161
|
+
"to_lightgbm()",
|
1162
|
+
"to_sklearn()"
|
1163
|
+
)
|
1164
|
+
),
|
1165
|
+
)
|
1166
|
+
|
1167
|
+
def _get_dependencies(self) -> List[str]:
|
1168
|
+
return self._deps
|
1169
|
+
|
1067
1170
|
|
1068
|
-
def
|
1171
|
+
def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
1069
1172
|
self._model_signature_dict = dict()
|
1070
1173
|
|
1071
1174
|
PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
|
1072
1175
|
|
1073
|
-
inputs = list(_infer_signature(dataset[self.input_cols], "input"))
|
1176
|
+
inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
|
1074
1177
|
outputs: List[BaseFeatureSpec] = []
|
1075
1178
|
if hasattr(self, "predict"):
|
1076
1179
|
# keep mypy happy
|
1077
|
-
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1180
|
+
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1078
1181
|
# For classifier, the type of predict is the same as the type of label
|
1079
|
-
if self._sklearn_object._estimator_type ==
|
1080
|
-
|
1182
|
+
if self._sklearn_object._estimator_type == "classifier":
|
1183
|
+
# label columns is the desired type for output
|
1081
1184
|
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
1082
1185
|
# rename the output columns
|
1083
1186
|
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
1084
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1085
|
-
|
1086
|
-
|
1187
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1188
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1189
|
+
)
|
1087
1190
|
# For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
|
1088
1191
|
# For outlier models, returns -1 for outliers and 1 for inliers.
|
1089
|
-
# Clusterer returns int64 cluster labels.
|
1192
|
+
# Clusterer returns int64 cluster labels.
|
1090
1193
|
elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
|
1091
1194
|
outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
|
1092
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1093
|
-
|
1094
|
-
|
1095
|
-
|
1195
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1196
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1197
|
+
)
|
1198
|
+
|
1096
1199
|
# For regressor, the type of predict is float64
|
1097
|
-
elif self._sklearn_object._estimator_type ==
|
1200
|
+
elif self._sklearn_object._estimator_type == "regressor":
|
1098
1201
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
1099
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1100
|
-
|
1101
|
-
|
1102
|
-
|
1202
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1203
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1204
|
+
)
|
1205
|
+
|
1103
1206
|
for prob_func in PROB_FUNCTIONS:
|
1104
1207
|
if hasattr(self, prob_func):
|
1105
1208
|
output_cols_prefix: str = f"{prob_func}_"
|
1106
1209
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
1107
1210
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
1108
|
-
self._model_signature_dict[prob_func] = ModelSignature(
|
1109
|
-
|
1110
|
-
|
1211
|
+
self._model_signature_dict[prob_func] = ModelSignature(
|
1212
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1213
|
+
)
|
1111
1214
|
|
1112
1215
|
# Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
|
1113
1216
|
items = list(self._model_signature_dict.items())
|
@@ -1120,10 +1223,10 @@ class ExtraTreesRegressor(BaseTransformer):
|
|
1120
1223
|
"""Returns model signature of current class.
|
1121
1224
|
|
1122
1225
|
Raises:
|
1123
|
-
|
1226
|
+
SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
|
1124
1227
|
|
1125
1228
|
Returns:
|
1126
|
-
Dict
|
1229
|
+
Dict with each method and its input output signature
|
1127
1230
|
"""
|
1128
1231
|
if self._model_signature_dict is None:
|
1129
1232
|
raise exceptions.SnowflakeMLException(
|
@@ -1131,35 +1234,3 @@ class ExtraTreesRegressor(BaseTransformer):
|
|
1131
1234
|
original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
|
1132
1235
|
)
|
1133
1236
|
return self._model_signature_dict
|
1134
|
-
|
1135
|
-
def to_sklearn(self) -> Any:
|
1136
|
-
"""Get sklearn.ensemble.ExtraTreesRegressor object.
|
1137
|
-
"""
|
1138
|
-
if self._sklearn_object is None:
|
1139
|
-
self._sklearn_object = self._create_sklearn_object()
|
1140
|
-
return self._sklearn_object
|
1141
|
-
|
1142
|
-
def to_xgboost(self) -> Any:
|
1143
|
-
raise exceptions.SnowflakeMLException(
|
1144
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1145
|
-
original_exception=AttributeError(
|
1146
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1147
|
-
"to_xgboost()",
|
1148
|
-
"to_sklearn()"
|
1149
|
-
)
|
1150
|
-
),
|
1151
|
-
)
|
1152
|
-
|
1153
|
-
def to_lightgbm(self) -> Any:
|
1154
|
-
raise exceptions.SnowflakeMLException(
|
1155
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1156
|
-
original_exception=AttributeError(
|
1157
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1158
|
-
"to_lightgbm()",
|
1159
|
-
"to_sklearn()"
|
1160
|
-
)
|
1161
|
-
),
|
1162
|
-
)
|
1163
|
-
|
1164
|
-
def _get_dependencies(self) -> List[str]:
|
1165
|
-
return self._deps
|