snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
|
|
33
33
|
BatchInferenceKwargsTypedDict,
|
34
34
|
ScoreKwargsTypedDict
|
35
35
|
)
|
36
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
37
|
+
from snowflake.ml.model.model_signature import (
|
38
|
+
BaseFeatureSpec,
|
39
|
+
DataType,
|
40
|
+
FeatureSpec,
|
41
|
+
ModelSignature,
|
42
|
+
_infer_signature,
|
43
|
+
_rename_signature_with_snowflake_identifiers,
|
44
|
+
)
|
36
45
|
|
37
46
|
from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
|
38
47
|
|
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
|
|
43
52
|
validate_sklearn_args,
|
44
53
|
)
|
45
54
|
|
46
|
-
from snowflake.ml.model.model_signature import (
|
47
|
-
DataType,
|
48
|
-
FeatureSpec,
|
49
|
-
ModelSignature,
|
50
|
-
_infer_signature,
|
51
|
-
_rename_signature_with_snowflake_identifiers,
|
52
|
-
BaseFeatureSpec,
|
53
|
-
)
|
54
|
-
from snowflake.ml.model._signatures import utils as model_signature_utils
|
55
|
-
|
56
55
|
_PROJECT = "ModelDevelopment"
|
57
56
|
# Derive subproject from module name by removing "sklearn"
|
58
57
|
# and converting module name from underscore to CamelCase
|
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
|
|
61
60
|
|
62
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
63
62
|
|
64
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
65
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
66
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
67
|
-
return check
|
68
|
-
|
69
|
-
|
70
63
|
class VotingClassifier(BaseTransformer):
|
71
64
|
r"""Soft Voting/Majority Rule classifier for unfitted estimators
|
72
65
|
For more details on this class, see [sklearn.ensemble.VotingClassifier]
|
@@ -237,12 +230,7 @@ class VotingClassifier(BaseTransformer):
|
|
237
230
|
)
|
238
231
|
return selected_cols
|
239
232
|
|
240
|
-
|
241
|
-
project=_PROJECT,
|
242
|
-
subproject=_SUBPROJECT,
|
243
|
-
custom_tags=dict([("autogen", True)]),
|
244
|
-
)
|
245
|
-
def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "VotingClassifier":
|
233
|
+
def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "VotingClassifier":
|
246
234
|
"""Fit the estimators
|
247
235
|
For more details on this function, see [sklearn.ensemble.VotingClassifier.fit]
|
248
236
|
(https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.VotingClassifier.html#sklearn.ensemble.VotingClassifier.fit)
|
@@ -269,12 +257,14 @@ class VotingClassifier(BaseTransformer):
|
|
269
257
|
|
270
258
|
self._snowpark_cols = dataset.select(self.input_cols).columns
|
271
259
|
|
272
|
-
|
260
|
+
# If we are already in a stored procedure, no need to kick off another one.
|
273
261
|
if SNOWML_SPROC_ENV in os.environ:
|
274
262
|
statement_params = telemetry.get_function_usage_statement_params(
|
275
263
|
project=_PROJECT,
|
276
264
|
subproject=_SUBPROJECT,
|
277
|
-
function_name=telemetry.get_statement_params_full_func_name(
|
265
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
266
|
+
inspect.currentframe(), VotingClassifier.__class__.__name__
|
267
|
+
),
|
278
268
|
api_calls=[Session.call],
|
279
269
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
280
270
|
)
|
@@ -295,27 +285,24 @@ class VotingClassifier(BaseTransformer):
|
|
295
285
|
)
|
296
286
|
self._sklearn_object = model_trainer.train()
|
297
287
|
self._is_fitted = True
|
298
|
-
self.
|
288
|
+
self._generate_model_signatures(dataset)
|
299
289
|
return self
|
300
290
|
|
301
291
|
def _batch_inference_validate_snowpark(
|
302
292
|
self,
|
303
293
|
dataset: DataFrame,
|
304
294
|
inference_method: str,
|
305
|
-
) ->
|
306
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
307
|
-
return the available package that exists in the snowflake anaconda channel
|
295
|
+
) -> None:
|
296
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
308
297
|
|
309
298
|
Args:
|
310
299
|
dataset: snowpark dataframe
|
311
300
|
inference_method: the inference method such as predict, score...
|
312
|
-
|
301
|
+
|
313
302
|
Raises:
|
314
303
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
315
304
|
SnowflakeMLException: If the session is None, raise error
|
316
305
|
|
317
|
-
Returns:
|
318
|
-
A list of available package that exists in the snowflake anaconda channel
|
319
306
|
"""
|
320
307
|
if not self._is_fitted:
|
321
308
|
raise exceptions.SnowflakeMLException(
|
@@ -333,9 +320,7 @@ class VotingClassifier(BaseTransformer):
|
|
333
320
|
"Session must not specified for snowpark dataset."
|
334
321
|
),
|
335
322
|
)
|
336
|
-
|
337
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
338
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
323
|
+
|
339
324
|
|
340
325
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
341
326
|
@telemetry.send_api_usage_telemetry(
|
@@ -371,7 +356,9 @@ class VotingClassifier(BaseTransformer):
|
|
371
356
|
# when it is classifier, infer the datatype from label columns
|
372
357
|
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
373
358
|
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
374
|
-
label_cols_signatures = [
|
359
|
+
label_cols_signatures = [
|
360
|
+
row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
|
361
|
+
]
|
375
362
|
if len(label_cols_signatures) == 0:
|
376
363
|
error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
|
377
364
|
raise exceptions.SnowflakeMLException(
|
@@ -379,25 +366,23 @@ class VotingClassifier(BaseTransformer):
|
|
379
366
|
original_exception=ValueError(error_str),
|
380
367
|
)
|
381
368
|
|
382
|
-
expected_type_inferred = convert_sp_to_sf_type(
|
383
|
-
label_cols_signatures[0].as_snowpark_type()
|
384
|
-
)
|
369
|
+
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
385
370
|
|
386
|
-
self.
|
387
|
-
|
371
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
372
|
+
self._deps = self._get_dependencies()
|
373
|
+
assert isinstance(
|
374
|
+
dataset._session, Session
|
375
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
388
376
|
|
389
377
|
transform_kwargs = dict(
|
390
|
-
session
|
391
|
-
dependencies
|
392
|
-
drop_input_cols
|
393
|
-
expected_output_cols_type
|
378
|
+
session=dataset._session,
|
379
|
+
dependencies=self._deps,
|
380
|
+
drop_input_cols=self._drop_input_cols,
|
381
|
+
expected_output_cols_type=expected_type_inferred,
|
394
382
|
)
|
395
383
|
|
396
384
|
elif isinstance(dataset, pd.DataFrame):
|
397
|
-
transform_kwargs = dict(
|
398
|
-
snowpark_input_cols = self._snowpark_cols,
|
399
|
-
drop_input_cols = self._drop_input_cols
|
400
|
-
)
|
385
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
401
386
|
|
402
387
|
transform_handlers = ModelTransformerBuilder.build(
|
403
388
|
dataset=dataset,
|
@@ -439,7 +424,7 @@ class VotingClassifier(BaseTransformer):
|
|
439
424
|
Transformed dataset.
|
440
425
|
"""
|
441
426
|
super()._check_dataset_type(dataset)
|
442
|
-
inference_method="transform"
|
427
|
+
inference_method = "transform"
|
443
428
|
|
444
429
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
445
430
|
# are specific to the type of dataset used.
|
@@ -469,24 +454,19 @@ class VotingClassifier(BaseTransformer):
|
|
469
454
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
470
455
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
471
456
|
|
472
|
-
self.
|
473
|
-
|
474
|
-
inference_method=inference_method,
|
475
|
-
)
|
457
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
458
|
+
self._deps = self._get_dependencies()
|
476
459
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
477
460
|
|
478
461
|
transform_kwargs = dict(
|
479
|
-
session
|
480
|
-
dependencies
|
481
|
-
drop_input_cols
|
482
|
-
expected_output_cols_type
|
462
|
+
session=dataset._session,
|
463
|
+
dependencies=self._deps,
|
464
|
+
drop_input_cols=self._drop_input_cols,
|
465
|
+
expected_output_cols_type=expected_dtype,
|
483
466
|
)
|
484
467
|
|
485
468
|
elif isinstance(dataset, pd.DataFrame):
|
486
|
-
transform_kwargs = dict(
|
487
|
-
snowpark_input_cols = self._snowpark_cols,
|
488
|
-
drop_input_cols = self._drop_input_cols
|
489
|
-
)
|
469
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
490
470
|
|
491
471
|
transform_handlers = ModelTransformerBuilder.build(
|
492
472
|
dataset=dataset,
|
@@ -505,7 +485,11 @@ class VotingClassifier(BaseTransformer):
|
|
505
485
|
return output_df
|
506
486
|
|
507
487
|
@available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
|
508
|
-
def fit_predict(
|
488
|
+
def fit_predict(
|
489
|
+
self,
|
490
|
+
dataset: Union[DataFrame, pd.DataFrame],
|
491
|
+
output_cols_prefix: str = "fit_predict_",
|
492
|
+
) -> Union[DataFrame, pd.DataFrame]:
|
509
493
|
""" Method not supported for this class.
|
510
494
|
|
511
495
|
|
@@ -530,22 +514,106 @@ class VotingClassifier(BaseTransformer):
|
|
530
514
|
)
|
531
515
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
532
516
|
drop_input_cols=self._drop_input_cols,
|
533
|
-
expected_output_cols_list=
|
517
|
+
expected_output_cols_list=(
|
518
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
519
|
+
),
|
534
520
|
)
|
535
521
|
self._sklearn_object = fitted_estimator
|
536
522
|
self._is_fitted = True
|
537
523
|
return output_result
|
538
524
|
|
525
|
+
|
526
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
527
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
528
|
+
""" Return class labels or probabilities for each estimator
|
529
|
+
For more details on this function, see [sklearn.ensemble.VotingClassifier.fit_transform]
|
530
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.VotingClassifier.html#sklearn.ensemble.VotingClassifier.fit_transform)
|
531
|
+
|
532
|
+
|
533
|
+
Raises:
|
534
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
539
535
|
|
540
|
-
|
541
|
-
|
542
|
-
|
536
|
+
Args:
|
537
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
538
|
+
Snowpark or Pandas DataFrame.
|
539
|
+
output_cols_prefix: Prefix for the response columns
|
543
540
|
Returns:
|
544
541
|
Transformed dataset.
|
545
542
|
"""
|
546
|
-
self.
|
547
|
-
|
548
|
-
|
543
|
+
self._infer_input_output_cols(dataset)
|
544
|
+
super()._check_dataset_type(dataset)
|
545
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
546
|
+
estimator=self._sklearn_object,
|
547
|
+
dataset=dataset,
|
548
|
+
input_cols=self.input_cols,
|
549
|
+
label_cols=self.label_cols,
|
550
|
+
sample_weight_col=self.sample_weight_col,
|
551
|
+
autogenerated=self._autogenerated,
|
552
|
+
subproject=_SUBPROJECT,
|
553
|
+
)
|
554
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
555
|
+
drop_input_cols=self._drop_input_cols,
|
556
|
+
expected_output_cols_list=self.output_cols,
|
557
|
+
)
|
558
|
+
self._sklearn_object = fitted_estimator
|
559
|
+
self._is_fitted = True
|
560
|
+
return output_result
|
561
|
+
|
562
|
+
|
563
|
+
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
564
|
+
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
565
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
566
|
+
"""
|
567
|
+
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
568
|
+
# The following condition is introduced for kneighbors methods, and not used in other methods
|
569
|
+
if output_cols:
|
570
|
+
output_cols = [
|
571
|
+
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
572
|
+
for c in output_cols
|
573
|
+
]
|
574
|
+
elif getattr(self._sklearn_object, "classes_", None) is None:
|
575
|
+
output_cols = [output_cols_prefix]
|
576
|
+
elif self._sklearn_object is not None:
|
577
|
+
classes = self._sklearn_object.classes_
|
578
|
+
if isinstance(classes, numpy.ndarray):
|
579
|
+
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
580
|
+
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
581
|
+
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
582
|
+
output_cols = []
|
583
|
+
for i, cl in enumerate(classes):
|
584
|
+
# For binary classification, there is only one output column for each class
|
585
|
+
# ndarray as the two classes are complementary.
|
586
|
+
if len(cl) == 2:
|
587
|
+
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
588
|
+
else:
|
589
|
+
output_cols.extend([
|
590
|
+
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
591
|
+
])
|
592
|
+
else:
|
593
|
+
output_cols = []
|
594
|
+
|
595
|
+
# Make sure column names are valid snowflake identifiers.
|
596
|
+
assert output_cols is not None # Make MyPy happy
|
597
|
+
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
598
|
+
|
599
|
+
return rv
|
600
|
+
|
601
|
+
def _align_expected_output_names(
|
602
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
603
|
+
) -> List[str]:
|
604
|
+
# in case the inferred output column names dimension is different
|
605
|
+
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
606
|
+
output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
|
607
|
+
output_df_columns = list(output_df_pd.columns)
|
608
|
+
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
609
|
+
if self.sample_weight_col:
|
610
|
+
output_df_columns_set -= set(self.sample_weight_col)
|
611
|
+
# if the dimension of inferred output column names is correct; use it
|
612
|
+
if len(expected_output_cols_list) == len(output_df_columns_set):
|
613
|
+
return expected_output_cols_list
|
614
|
+
# otherwise, use the sklearn estimator's output
|
615
|
+
else:
|
616
|
+
return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
549
617
|
|
550
618
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
551
619
|
@telemetry.send_api_usage_telemetry(
|
@@ -579,24 +647,26 @@ class VotingClassifier(BaseTransformer):
|
|
579
647
|
# are specific to the type of dataset used.
|
580
648
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
581
649
|
|
650
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
651
|
+
|
582
652
|
if isinstance(dataset, DataFrame):
|
583
|
-
self.
|
584
|
-
|
585
|
-
|
586
|
-
|
587
|
-
|
653
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
654
|
+
self._deps = self._get_dependencies()
|
655
|
+
assert isinstance(
|
656
|
+
dataset._session, Session
|
657
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
588
658
|
transform_kwargs = dict(
|
589
659
|
session=dataset._session,
|
590
660
|
dependencies=self._deps,
|
591
|
-
drop_input_cols
|
661
|
+
drop_input_cols=self._drop_input_cols,
|
592
662
|
expected_output_cols_type="float",
|
593
663
|
)
|
664
|
+
expected_output_cols = self._align_expected_output_names(
|
665
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
666
|
+
)
|
594
667
|
|
595
668
|
elif isinstance(dataset, pd.DataFrame):
|
596
|
-
transform_kwargs = dict(
|
597
|
-
snowpark_input_cols = self._snowpark_cols,
|
598
|
-
drop_input_cols = self._drop_input_cols
|
599
|
-
)
|
669
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
600
670
|
|
601
671
|
transform_handlers = ModelTransformerBuilder.build(
|
602
672
|
dataset=dataset,
|
@@ -608,7 +678,7 @@ class VotingClassifier(BaseTransformer):
|
|
608
678
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
609
679
|
inference_method=inference_method,
|
610
680
|
input_cols=self.input_cols,
|
611
|
-
expected_output_cols=
|
681
|
+
expected_output_cols=expected_output_cols,
|
612
682
|
**transform_kwargs
|
613
683
|
)
|
614
684
|
return output_df
|
@@ -640,29 +710,30 @@ class VotingClassifier(BaseTransformer):
|
|
640
710
|
Output dataset with log probability of the sample for each class in the model.
|
641
711
|
"""
|
642
712
|
super()._check_dataset_type(dataset)
|
643
|
-
inference_method="predict_log_proba"
|
713
|
+
inference_method = "predict_log_proba"
|
714
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
644
715
|
|
645
716
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
646
717
|
# are specific to the type of dataset used.
|
647
718
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
648
719
|
|
649
720
|
if isinstance(dataset, DataFrame):
|
650
|
-
self.
|
651
|
-
|
652
|
-
|
653
|
-
|
654
|
-
|
721
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
722
|
+
self._deps = self._get_dependencies()
|
723
|
+
assert isinstance(
|
724
|
+
dataset._session, Session
|
725
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
655
726
|
transform_kwargs = dict(
|
656
727
|
session=dataset._session,
|
657
728
|
dependencies=self._deps,
|
658
|
-
drop_input_cols
|
729
|
+
drop_input_cols=self._drop_input_cols,
|
659
730
|
expected_output_cols_type="float",
|
660
731
|
)
|
732
|
+
expected_output_cols = self._align_expected_output_names(
|
733
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
734
|
+
)
|
661
735
|
elif isinstance(dataset, pd.DataFrame):
|
662
|
-
transform_kwargs = dict(
|
663
|
-
snowpark_input_cols = self._snowpark_cols,
|
664
|
-
drop_input_cols = self._drop_input_cols
|
665
|
-
)
|
736
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
666
737
|
|
667
738
|
transform_handlers = ModelTransformerBuilder.build(
|
668
739
|
dataset=dataset,
|
@@ -675,7 +746,7 @@ class VotingClassifier(BaseTransformer):
|
|
675
746
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
676
747
|
inference_method=inference_method,
|
677
748
|
input_cols=self.input_cols,
|
678
|
-
expected_output_cols=
|
749
|
+
expected_output_cols=expected_output_cols,
|
679
750
|
**transform_kwargs
|
680
751
|
)
|
681
752
|
return output_df
|
@@ -701,30 +772,32 @@ class VotingClassifier(BaseTransformer):
|
|
701
772
|
Output dataset with results of the decision function for the samples in input dataset.
|
702
773
|
"""
|
703
774
|
super()._check_dataset_type(dataset)
|
704
|
-
inference_method="decision_function"
|
775
|
+
inference_method = "decision_function"
|
705
776
|
|
706
777
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
707
778
|
# are specific to the type of dataset used.
|
708
779
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
709
780
|
|
781
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
782
|
+
|
710
783
|
if isinstance(dataset, DataFrame):
|
711
|
-
self.
|
712
|
-
|
713
|
-
|
714
|
-
|
715
|
-
|
784
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
785
|
+
self._deps = self._get_dependencies()
|
786
|
+
assert isinstance(
|
787
|
+
dataset._session, Session
|
788
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
716
789
|
transform_kwargs = dict(
|
717
790
|
session=dataset._session,
|
718
791
|
dependencies=self._deps,
|
719
|
-
drop_input_cols
|
792
|
+
drop_input_cols=self._drop_input_cols,
|
720
793
|
expected_output_cols_type="float",
|
721
794
|
)
|
795
|
+
expected_output_cols = self._align_expected_output_names(
|
796
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
797
|
+
)
|
722
798
|
|
723
799
|
elif isinstance(dataset, pd.DataFrame):
|
724
|
-
transform_kwargs = dict(
|
725
|
-
snowpark_input_cols = self._snowpark_cols,
|
726
|
-
drop_input_cols = self._drop_input_cols
|
727
|
-
)
|
800
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
728
801
|
|
729
802
|
transform_handlers = ModelTransformerBuilder.build(
|
730
803
|
dataset=dataset,
|
@@ -737,7 +810,7 @@ class VotingClassifier(BaseTransformer):
|
|
737
810
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
738
811
|
inference_method=inference_method,
|
739
812
|
input_cols=self.input_cols,
|
740
|
-
expected_output_cols=
|
813
|
+
expected_output_cols=expected_output_cols,
|
741
814
|
**transform_kwargs
|
742
815
|
)
|
743
816
|
return output_df
|
@@ -766,17 +839,17 @@ class VotingClassifier(BaseTransformer):
|
|
766
839
|
Output dataset with probability of the sample for each class in the model.
|
767
840
|
"""
|
768
841
|
super()._check_dataset_type(dataset)
|
769
|
-
inference_method="score_samples"
|
842
|
+
inference_method = "score_samples"
|
770
843
|
|
771
844
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
772
845
|
# are specific to the type of dataset used.
|
773
846
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
774
847
|
|
848
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
849
|
+
|
775
850
|
if isinstance(dataset, DataFrame):
|
776
|
-
self.
|
777
|
-
|
778
|
-
inference_method=inference_method,
|
779
|
-
)
|
851
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
852
|
+
self._deps = self._get_dependencies()
|
780
853
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
781
854
|
transform_kwargs = dict(
|
782
855
|
session=dataset._session,
|
@@ -784,6 +857,9 @@ class VotingClassifier(BaseTransformer):
|
|
784
857
|
drop_input_cols = self._drop_input_cols,
|
785
858
|
expected_output_cols_type="float",
|
786
859
|
)
|
860
|
+
expected_output_cols = self._align_expected_output_names(
|
861
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
862
|
+
)
|
787
863
|
|
788
864
|
elif isinstance(dataset, pd.DataFrame):
|
789
865
|
transform_kwargs = dict(
|
@@ -802,7 +878,7 @@ class VotingClassifier(BaseTransformer):
|
|
802
878
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
803
879
|
inference_method=inference_method,
|
804
880
|
input_cols=self.input_cols,
|
805
|
-
expected_output_cols=
|
881
|
+
expected_output_cols=expected_output_cols,
|
806
882
|
**transform_kwargs
|
807
883
|
)
|
808
884
|
return output_df
|
@@ -837,17 +913,15 @@ class VotingClassifier(BaseTransformer):
|
|
837
913
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
838
914
|
|
839
915
|
if isinstance(dataset, DataFrame):
|
840
|
-
self.
|
841
|
-
|
842
|
-
inference_method="score",
|
843
|
-
)
|
916
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
917
|
+
self._deps = self._get_dependencies()
|
844
918
|
selected_cols = self._get_active_columns()
|
845
919
|
if len(selected_cols) > 0:
|
846
920
|
dataset = dataset.select(selected_cols)
|
847
921
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
848
922
|
transform_kwargs = dict(
|
849
923
|
session=dataset._session,
|
850
|
-
dependencies=
|
924
|
+
dependencies=self._deps,
|
851
925
|
score_sproc_imports=['sklearn'],
|
852
926
|
)
|
853
927
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -912,11 +986,8 @@ class VotingClassifier(BaseTransformer):
|
|
912
986
|
|
913
987
|
if isinstance(dataset, DataFrame):
|
914
988
|
|
915
|
-
self.
|
916
|
-
|
917
|
-
inference_method=inference_method,
|
918
|
-
|
919
|
-
)
|
989
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
990
|
+
self._deps = self._get_dependencies()
|
920
991
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
921
992
|
transform_kwargs = dict(
|
922
993
|
session = dataset._session,
|
@@ -949,50 +1020,84 @@ class VotingClassifier(BaseTransformer):
|
|
949
1020
|
)
|
950
1021
|
return output_df
|
951
1022
|
|
1023
|
+
|
1024
|
+
|
1025
|
+
def to_sklearn(self) -> Any:
|
1026
|
+
"""Get sklearn.ensemble.VotingClassifier object.
|
1027
|
+
"""
|
1028
|
+
if self._sklearn_object is None:
|
1029
|
+
self._sklearn_object = self._create_sklearn_object()
|
1030
|
+
return self._sklearn_object
|
1031
|
+
|
1032
|
+
def to_xgboost(self) -> Any:
|
1033
|
+
raise exceptions.SnowflakeMLException(
|
1034
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1035
|
+
original_exception=AttributeError(
|
1036
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1037
|
+
"to_xgboost()",
|
1038
|
+
"to_sklearn()"
|
1039
|
+
)
|
1040
|
+
),
|
1041
|
+
)
|
952
1042
|
|
953
|
-
def
|
1043
|
+
def to_lightgbm(self) -> Any:
|
1044
|
+
raise exceptions.SnowflakeMLException(
|
1045
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1046
|
+
original_exception=AttributeError(
|
1047
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1048
|
+
"to_lightgbm()",
|
1049
|
+
"to_sklearn()"
|
1050
|
+
)
|
1051
|
+
),
|
1052
|
+
)
|
1053
|
+
|
1054
|
+
def _get_dependencies(self) -> List[str]:
|
1055
|
+
return self._deps
|
1056
|
+
|
1057
|
+
|
1058
|
+
def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
954
1059
|
self._model_signature_dict = dict()
|
955
1060
|
|
956
1061
|
PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
|
957
1062
|
|
958
|
-
inputs = list(_infer_signature(dataset[self.input_cols], "input"))
|
1063
|
+
inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
|
959
1064
|
outputs: List[BaseFeatureSpec] = []
|
960
1065
|
if hasattr(self, "predict"):
|
961
1066
|
# keep mypy happy
|
962
|
-
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1067
|
+
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
963
1068
|
# For classifier, the type of predict is the same as the type of label
|
964
|
-
if self._sklearn_object._estimator_type ==
|
965
|
-
|
1069
|
+
if self._sklearn_object._estimator_type == "classifier":
|
1070
|
+
# label columns is the desired type for output
|
966
1071
|
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
967
1072
|
# rename the output columns
|
968
1073
|
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
969
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
970
|
-
|
971
|
-
|
1074
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1075
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1076
|
+
)
|
972
1077
|
# For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
|
973
1078
|
# For outlier models, returns -1 for outliers and 1 for inliers.
|
974
|
-
# Clusterer returns int64 cluster labels.
|
1079
|
+
# Clusterer returns int64 cluster labels.
|
975
1080
|
elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
|
976
1081
|
outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
|
977
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
978
|
-
|
979
|
-
|
980
|
-
|
1082
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1083
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1084
|
+
)
|
1085
|
+
|
981
1086
|
# For regressor, the type of predict is float64
|
982
|
-
elif self._sklearn_object._estimator_type ==
|
1087
|
+
elif self._sklearn_object._estimator_type == "regressor":
|
983
1088
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
984
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
985
|
-
|
986
|
-
|
987
|
-
|
1089
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1090
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1091
|
+
)
|
1092
|
+
|
988
1093
|
for prob_func in PROB_FUNCTIONS:
|
989
1094
|
if hasattr(self, prob_func):
|
990
1095
|
output_cols_prefix: str = f"{prob_func}_"
|
991
1096
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
992
1097
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
993
|
-
self._model_signature_dict[prob_func] = ModelSignature(
|
994
|
-
|
995
|
-
|
1098
|
+
self._model_signature_dict[prob_func] = ModelSignature(
|
1099
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1100
|
+
)
|
996
1101
|
|
997
1102
|
# Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
|
998
1103
|
items = list(self._model_signature_dict.items())
|
@@ -1005,10 +1110,10 @@ class VotingClassifier(BaseTransformer):
|
|
1005
1110
|
"""Returns model signature of current class.
|
1006
1111
|
|
1007
1112
|
Raises:
|
1008
|
-
|
1113
|
+
SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
|
1009
1114
|
|
1010
1115
|
Returns:
|
1011
|
-
Dict
|
1116
|
+
Dict with each method and its input output signature
|
1012
1117
|
"""
|
1013
1118
|
if self._model_signature_dict is None:
|
1014
1119
|
raise exceptions.SnowflakeMLException(
|
@@ -1016,35 +1121,3 @@ class VotingClassifier(BaseTransformer):
|
|
1016
1121
|
original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
|
1017
1122
|
)
|
1018
1123
|
return self._model_signature_dict
|
1019
|
-
|
1020
|
-
def to_sklearn(self) -> Any:
|
1021
|
-
"""Get sklearn.ensemble.VotingClassifier object.
|
1022
|
-
"""
|
1023
|
-
if self._sklearn_object is None:
|
1024
|
-
self._sklearn_object = self._create_sklearn_object()
|
1025
|
-
return self._sklearn_object
|
1026
|
-
|
1027
|
-
def to_xgboost(self) -> Any:
|
1028
|
-
raise exceptions.SnowflakeMLException(
|
1029
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1030
|
-
original_exception=AttributeError(
|
1031
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1032
|
-
"to_xgboost()",
|
1033
|
-
"to_sklearn()"
|
1034
|
-
)
|
1035
|
-
),
|
1036
|
-
)
|
1037
|
-
|
1038
|
-
def to_lightgbm(self) -> Any:
|
1039
|
-
raise exceptions.SnowflakeMLException(
|
1040
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1041
|
-
original_exception=AttributeError(
|
1042
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1043
|
-
"to_lightgbm()",
|
1044
|
-
"to_sklearn()"
|
1045
|
-
)
|
1046
|
-
),
|
1047
|
-
)
|
1048
|
-
|
1049
|
-
def _get_dependencies(self) -> List[str]:
|
1050
|
-
return self._deps
|