snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/telemetry.py +142 -20
- snowflake/ml/_internal/utils/identifier.py +48 -11
- snowflake/ml/_internal/utils/snowflake_env.py +23 -13
- snowflake/ml/_internal/utils/sql_identifier.py +1 -1
- snowflake/ml/_internal/utils/table_manager.py +19 -1
- snowflake/ml/_internal/utils/uri.py +2 -2
- snowflake/ml/data/data_connector.py +33 -7
- snowflake/ml/data/torch_utils.py +68 -0
- snowflake/ml/dataset/dataset.py +1 -3
- snowflake/ml/feature_store/feature_store.py +41 -17
- snowflake/ml/feature_store/feature_view.py +2 -2
- snowflake/ml/fileset/embedded_stage_fs.py +1 -1
- snowflake/ml/fileset/fileset.py +1 -1
- snowflake/ml/fileset/sfcfs.py +9 -3
- snowflake/ml/model/_client/model/model_version_impl.py +22 -7
- snowflake/ml/model/_client/ops/model_ops.py +39 -3
- snowflake/ml/model/_client/ops/service_ops.py +198 -7
- snowflake/ml/model/_client/service/model_deployment_spec.py +4 -5
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -2
- snowflake/ml/model/_client/sql/service.py +85 -18
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
- snowflake/ml/model/_model_composer/model_composer.py +2 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -8
- snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
- snowflake/ml/model/_packager/model_handlers/catboost.py +17 -15
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +23 -15
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +15 -57
- snowflake/ml/model/_packager/model_handlers/llm.py +4 -2
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +36 -24
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +119 -6
- snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
- snowflake/ml/model/_packager/model_handlers/xgboost.py +48 -48
- snowflake/ml/model/_packager/model_meta/model_meta.py +10 -7
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +0 -8
- snowflake/ml/model/_packager/model_packager.py +2 -0
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/_signatures/utils.py +9 -0
- snowflake/ml/model/models/llm.py +3 -1
- snowflake/ml/model/type_hints.py +9 -1
- snowflake/ml/modeling/_internal/constants.py +1 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
- snowflake/ml/modeling/_internal/model_specifications.py +2 -0
- snowflake/ml/modeling/_internal/model_trainer.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
- snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
- snowflake/ml/modeling/cluster/birch.py +60 -21
- snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
- snowflake/ml/modeling/cluster/dbscan.py +60 -21
- snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
- snowflake/ml/modeling/cluster/k_means.py +60 -21
- snowflake/ml/modeling/cluster/mean_shift.py +60 -21
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
- snowflake/ml/modeling/cluster/optics.py +60 -21
- snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
- snowflake/ml/modeling/compose/column_transformer.py +60 -21
- snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
- snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
- snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
- snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
- snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
- snowflake/ml/modeling/covariance/oas.py +60 -21
- snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
- snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
- snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
- snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
- snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/pca.py +60 -21
- snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
- snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
- snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
- snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
- snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
- snowflake/ml/modeling/impute/knn_imputer.py +60 -21
- snowflake/ml/modeling/impute/missing_indicator.py +60 -21
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/lars.py +60 -21
- snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
- snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/perceptron.py +60 -21
- snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ridge.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
- snowflake/ml/modeling/manifold/isomap.py +60 -21
- snowflake/ml/modeling/manifold/mds.py +60 -21
- snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
- snowflake/ml/modeling/manifold/tsne.py +60 -21
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
- snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
- snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +1 -12
- snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
- snowflake/ml/modeling/svm/linear_svc.py +60 -21
- snowflake/ml/modeling/svm/linear_svr.py +60 -21
- snowflake/ml/modeling/svm/nu_svc.py +60 -21
- snowflake/ml/modeling/svm/nu_svr.py +60 -21
- snowflake/ml/modeling/svm/svc.py +60 -21
- snowflake/ml/modeling/svm/svr.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
- snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
- snowflake/ml/registry/_manager/model_manager.py +4 -0
- snowflake/ml/registry/model_registry.py +1 -1
- snowflake/ml/registry/registry.py +1 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +23 -4
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +211 -209
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
- snowflake/ml/data/torch_dataset.py +0 -33
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -533,12 +530,23 @@ class NuSVR(BaseTransformer):
|
|
533
530
|
autogenerated=self._autogenerated,
|
534
531
|
subproject=_SUBPROJECT,
|
535
532
|
)
|
536
|
-
|
537
|
-
|
538
|
-
expected_output_cols_list=(
|
539
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
540
|
-
),
|
533
|
+
expected_output_cols = (
|
534
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
541
535
|
)
|
536
|
+
if isinstance(dataset, DataFrame):
|
537
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
538
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
539
|
+
)
|
540
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
541
|
+
drop_input_cols=self._drop_input_cols,
|
542
|
+
expected_output_cols_list=expected_output_cols,
|
543
|
+
example_output_pd_df=example_output_pd_df,
|
544
|
+
)
|
545
|
+
else:
|
546
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
547
|
+
drop_input_cols=self._drop_input_cols,
|
548
|
+
expected_output_cols_list=expected_output_cols,
|
549
|
+
)
|
542
550
|
self._sklearn_object = fitted_estimator
|
543
551
|
self._is_fitted = True
|
544
552
|
return output_result
|
@@ -617,12 +625,41 @@ class NuSVR(BaseTransformer):
|
|
617
625
|
|
618
626
|
return rv
|
619
627
|
|
620
|
-
def
|
621
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
622
|
-
) -> List[str]:
|
628
|
+
def _align_expected_output(
|
629
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
630
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
631
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
632
|
+
and output dataframe with 1 line.
|
633
|
+
If the method is fit_predict, run 2 lines of data.
|
634
|
+
"""
|
623
635
|
# in case the inferred output column names dimension is different
|
624
636
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
625
|
-
|
637
|
+
|
638
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
639
|
+
# so change the minimum of number of rows to 2
|
640
|
+
num_examples = 2
|
641
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
642
|
+
project=_PROJECT,
|
643
|
+
subproject=_SUBPROJECT,
|
644
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
645
|
+
inspect.currentframe(), NuSVR.__class__.__name__
|
646
|
+
),
|
647
|
+
api_calls=[Session.call],
|
648
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
649
|
+
)
|
650
|
+
if output_cols_prefix == "fit_predict_":
|
651
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
652
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
653
|
+
num_examples = self._sklearn_object.n_clusters
|
654
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
655
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
656
|
+
num_examples = self._sklearn_object.min_samples
|
657
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
658
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
659
|
+
num_examples = self._sklearn_object.n_neighbors
|
660
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
661
|
+
else:
|
662
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
626
663
|
|
627
664
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
628
665
|
# seen during the fit.
|
@@ -634,12 +671,14 @@ class NuSVR(BaseTransformer):
|
|
634
671
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
635
672
|
if self.sample_weight_col:
|
636
673
|
output_df_columns_set -= set(self.sample_weight_col)
|
674
|
+
|
637
675
|
# if the dimension of inferred output column names is correct; use it
|
638
676
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
639
|
-
return expected_output_cols_list
|
677
|
+
return expected_output_cols_list, output_df_pd
|
640
678
|
# otherwise, use the sklearn estimator's output
|
641
679
|
else:
|
642
|
-
|
680
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
681
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
643
682
|
|
644
683
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
645
684
|
@telemetry.send_api_usage_telemetry(
|
@@ -685,7 +724,7 @@ class NuSVR(BaseTransformer):
|
|
685
724
|
drop_input_cols=self._drop_input_cols,
|
686
725
|
expected_output_cols_type="float",
|
687
726
|
)
|
688
|
-
expected_output_cols = self.
|
727
|
+
expected_output_cols, _ = self._align_expected_output(
|
689
728
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
690
729
|
)
|
691
730
|
|
@@ -751,7 +790,7 @@ class NuSVR(BaseTransformer):
|
|
751
790
|
drop_input_cols=self._drop_input_cols,
|
752
791
|
expected_output_cols_type="float",
|
753
792
|
)
|
754
|
-
expected_output_cols = self.
|
793
|
+
expected_output_cols, _ = self._align_expected_output(
|
755
794
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
756
795
|
)
|
757
796
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -814,7 +853,7 @@ class NuSVR(BaseTransformer):
|
|
814
853
|
drop_input_cols=self._drop_input_cols,
|
815
854
|
expected_output_cols_type="float",
|
816
855
|
)
|
817
|
-
expected_output_cols = self.
|
856
|
+
expected_output_cols, _ = self._align_expected_output(
|
818
857
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
819
858
|
)
|
820
859
|
|
@@ -879,7 +918,7 @@ class NuSVR(BaseTransformer):
|
|
879
918
|
drop_input_cols = self._drop_input_cols,
|
880
919
|
expected_output_cols_type="float",
|
881
920
|
)
|
882
|
-
expected_output_cols = self.
|
921
|
+
expected_output_cols, _ = self._align_expected_output(
|
883
922
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
884
923
|
)
|
885
924
|
|
snowflake/ml/modeling/svm/svc.py
CHANGED
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -575,12 +572,23 @@ class SVC(BaseTransformer):
|
|
575
572
|
autogenerated=self._autogenerated,
|
576
573
|
subproject=_SUBPROJECT,
|
577
574
|
)
|
578
|
-
|
579
|
-
|
580
|
-
expected_output_cols_list=(
|
581
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
582
|
-
),
|
575
|
+
expected_output_cols = (
|
576
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
583
577
|
)
|
578
|
+
if isinstance(dataset, DataFrame):
|
579
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
580
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
581
|
+
)
|
582
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
583
|
+
drop_input_cols=self._drop_input_cols,
|
584
|
+
expected_output_cols_list=expected_output_cols,
|
585
|
+
example_output_pd_df=example_output_pd_df,
|
586
|
+
)
|
587
|
+
else:
|
588
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
589
|
+
drop_input_cols=self._drop_input_cols,
|
590
|
+
expected_output_cols_list=expected_output_cols,
|
591
|
+
)
|
584
592
|
self._sklearn_object = fitted_estimator
|
585
593
|
self._is_fitted = True
|
586
594
|
return output_result
|
@@ -659,12 +667,41 @@ class SVC(BaseTransformer):
|
|
659
667
|
|
660
668
|
return rv
|
661
669
|
|
662
|
-
def
|
663
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
664
|
-
) -> List[str]:
|
670
|
+
def _align_expected_output(
|
671
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
672
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
673
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
674
|
+
and output dataframe with 1 line.
|
675
|
+
If the method is fit_predict, run 2 lines of data.
|
676
|
+
"""
|
665
677
|
# in case the inferred output column names dimension is different
|
666
678
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
667
|
-
|
679
|
+
|
680
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
681
|
+
# so change the minimum of number of rows to 2
|
682
|
+
num_examples = 2
|
683
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
684
|
+
project=_PROJECT,
|
685
|
+
subproject=_SUBPROJECT,
|
686
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
687
|
+
inspect.currentframe(), SVC.__class__.__name__
|
688
|
+
),
|
689
|
+
api_calls=[Session.call],
|
690
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
691
|
+
)
|
692
|
+
if output_cols_prefix == "fit_predict_":
|
693
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
694
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
695
|
+
num_examples = self._sklearn_object.n_clusters
|
696
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
697
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
698
|
+
num_examples = self._sklearn_object.min_samples
|
699
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
700
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
701
|
+
num_examples = self._sklearn_object.n_neighbors
|
702
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
703
|
+
else:
|
704
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
668
705
|
|
669
706
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
670
707
|
# seen during the fit.
|
@@ -676,12 +713,14 @@ class SVC(BaseTransformer):
|
|
676
713
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
677
714
|
if self.sample_weight_col:
|
678
715
|
output_df_columns_set -= set(self.sample_weight_col)
|
716
|
+
|
679
717
|
# if the dimension of inferred output column names is correct; use it
|
680
718
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
681
|
-
return expected_output_cols_list
|
719
|
+
return expected_output_cols_list, output_df_pd
|
682
720
|
# otherwise, use the sklearn estimator's output
|
683
721
|
else:
|
684
|
-
|
722
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
723
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
685
724
|
|
686
725
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
687
726
|
@telemetry.send_api_usage_telemetry(
|
@@ -729,7 +768,7 @@ class SVC(BaseTransformer):
|
|
729
768
|
drop_input_cols=self._drop_input_cols,
|
730
769
|
expected_output_cols_type="float",
|
731
770
|
)
|
732
|
-
expected_output_cols = self.
|
771
|
+
expected_output_cols, _ = self._align_expected_output(
|
733
772
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
734
773
|
)
|
735
774
|
|
@@ -797,7 +836,7 @@ class SVC(BaseTransformer):
|
|
797
836
|
drop_input_cols=self._drop_input_cols,
|
798
837
|
expected_output_cols_type="float",
|
799
838
|
)
|
800
|
-
expected_output_cols = self.
|
839
|
+
expected_output_cols, _ = self._align_expected_output(
|
801
840
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
802
841
|
)
|
803
842
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -862,7 +901,7 @@ class SVC(BaseTransformer):
|
|
862
901
|
drop_input_cols=self._drop_input_cols,
|
863
902
|
expected_output_cols_type="float",
|
864
903
|
)
|
865
|
-
expected_output_cols = self.
|
904
|
+
expected_output_cols, _ = self._align_expected_output(
|
866
905
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
867
906
|
)
|
868
907
|
|
@@ -927,7 +966,7 @@ class SVC(BaseTransformer):
|
|
927
966
|
drop_input_cols = self._drop_input_cols,
|
928
967
|
expected_output_cols_type="float",
|
929
968
|
)
|
930
|
-
expected_output_cols = self.
|
969
|
+
expected_output_cols, _ = self._align_expected_output(
|
931
970
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
932
971
|
)
|
933
972
|
|
snowflake/ml/modeling/svm/svr.py
CHANGED
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -536,12 +533,23 @@ class SVR(BaseTransformer):
|
|
536
533
|
autogenerated=self._autogenerated,
|
537
534
|
subproject=_SUBPROJECT,
|
538
535
|
)
|
539
|
-
|
540
|
-
|
541
|
-
expected_output_cols_list=(
|
542
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
543
|
-
),
|
536
|
+
expected_output_cols = (
|
537
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
544
538
|
)
|
539
|
+
if isinstance(dataset, DataFrame):
|
540
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
541
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
542
|
+
)
|
543
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
544
|
+
drop_input_cols=self._drop_input_cols,
|
545
|
+
expected_output_cols_list=expected_output_cols,
|
546
|
+
example_output_pd_df=example_output_pd_df,
|
547
|
+
)
|
548
|
+
else:
|
549
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
550
|
+
drop_input_cols=self._drop_input_cols,
|
551
|
+
expected_output_cols_list=expected_output_cols,
|
552
|
+
)
|
545
553
|
self._sklearn_object = fitted_estimator
|
546
554
|
self._is_fitted = True
|
547
555
|
return output_result
|
@@ -620,12 +628,41 @@ class SVR(BaseTransformer):
|
|
620
628
|
|
621
629
|
return rv
|
622
630
|
|
623
|
-
def
|
624
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
625
|
-
) -> List[str]:
|
631
|
+
def _align_expected_output(
|
632
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
633
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
634
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
635
|
+
and output dataframe with 1 line.
|
636
|
+
If the method is fit_predict, run 2 lines of data.
|
637
|
+
"""
|
626
638
|
# in case the inferred output column names dimension is different
|
627
639
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
628
|
-
|
640
|
+
|
641
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
642
|
+
# so change the minimum of number of rows to 2
|
643
|
+
num_examples = 2
|
644
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
645
|
+
project=_PROJECT,
|
646
|
+
subproject=_SUBPROJECT,
|
647
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
648
|
+
inspect.currentframe(), SVR.__class__.__name__
|
649
|
+
),
|
650
|
+
api_calls=[Session.call],
|
651
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
652
|
+
)
|
653
|
+
if output_cols_prefix == "fit_predict_":
|
654
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
655
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
656
|
+
num_examples = self._sklearn_object.n_clusters
|
657
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
658
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
659
|
+
num_examples = self._sklearn_object.min_samples
|
660
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
661
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
662
|
+
num_examples = self._sklearn_object.n_neighbors
|
663
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
664
|
+
else:
|
665
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
629
666
|
|
630
667
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
631
668
|
# seen during the fit.
|
@@ -637,12 +674,14 @@ class SVR(BaseTransformer):
|
|
637
674
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
638
675
|
if self.sample_weight_col:
|
639
676
|
output_df_columns_set -= set(self.sample_weight_col)
|
677
|
+
|
640
678
|
# if the dimension of inferred output column names is correct; use it
|
641
679
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
642
|
-
return expected_output_cols_list
|
680
|
+
return expected_output_cols_list, output_df_pd
|
643
681
|
# otherwise, use the sklearn estimator's output
|
644
682
|
else:
|
645
|
-
|
683
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
684
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
646
685
|
|
647
686
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
648
687
|
@telemetry.send_api_usage_telemetry(
|
@@ -688,7 +727,7 @@ class SVR(BaseTransformer):
|
|
688
727
|
drop_input_cols=self._drop_input_cols,
|
689
728
|
expected_output_cols_type="float",
|
690
729
|
)
|
691
|
-
expected_output_cols = self.
|
730
|
+
expected_output_cols, _ = self._align_expected_output(
|
692
731
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
693
732
|
)
|
694
733
|
|
@@ -754,7 +793,7 @@ class SVR(BaseTransformer):
|
|
754
793
|
drop_input_cols=self._drop_input_cols,
|
755
794
|
expected_output_cols_type="float",
|
756
795
|
)
|
757
|
-
expected_output_cols = self.
|
796
|
+
expected_output_cols, _ = self._align_expected_output(
|
758
797
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
759
798
|
)
|
760
799
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -817,7 +856,7 @@ class SVR(BaseTransformer):
|
|
817
856
|
drop_input_cols=self._drop_input_cols,
|
818
857
|
expected_output_cols_type="float",
|
819
858
|
)
|
820
|
-
expected_output_cols = self.
|
859
|
+
expected_output_cols, _ = self._align_expected_output(
|
821
860
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
822
861
|
)
|
823
862
|
|
@@ -882,7 +921,7 @@ class SVR(BaseTransformer):
|
|
882
921
|
drop_input_cols = self._drop_input_cols,
|
883
922
|
expected_output_cols_type="float",
|
884
923
|
)
|
885
|
-
expected_output_cols = self.
|
924
|
+
expected_output_cols, _ = self._align_expected_output(
|
886
925
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
887
926
|
)
|
888
927
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -603,12 +600,23 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
603
600
|
autogenerated=self._autogenerated,
|
604
601
|
subproject=_SUBPROJECT,
|
605
602
|
)
|
606
|
-
|
607
|
-
|
608
|
-
expected_output_cols_list=(
|
609
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
610
|
-
),
|
603
|
+
expected_output_cols = (
|
604
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
611
605
|
)
|
606
|
+
if isinstance(dataset, DataFrame):
|
607
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
608
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
609
|
+
)
|
610
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
611
|
+
drop_input_cols=self._drop_input_cols,
|
612
|
+
expected_output_cols_list=expected_output_cols,
|
613
|
+
example_output_pd_df=example_output_pd_df,
|
614
|
+
)
|
615
|
+
else:
|
616
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
617
|
+
drop_input_cols=self._drop_input_cols,
|
618
|
+
expected_output_cols_list=expected_output_cols,
|
619
|
+
)
|
612
620
|
self._sklearn_object = fitted_estimator
|
613
621
|
self._is_fitted = True
|
614
622
|
return output_result
|
@@ -687,12 +695,41 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
687
695
|
|
688
696
|
return rv
|
689
697
|
|
690
|
-
def
|
691
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
692
|
-
) -> List[str]:
|
698
|
+
def _align_expected_output(
|
699
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
700
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
701
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
702
|
+
and output dataframe with 1 line.
|
703
|
+
If the method is fit_predict, run 2 lines of data.
|
704
|
+
"""
|
693
705
|
# in case the inferred output column names dimension is different
|
694
706
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
695
|
-
|
707
|
+
|
708
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
709
|
+
# so change the minimum of number of rows to 2
|
710
|
+
num_examples = 2
|
711
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
712
|
+
project=_PROJECT,
|
713
|
+
subproject=_SUBPROJECT,
|
714
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
715
|
+
inspect.currentframe(), DecisionTreeClassifier.__class__.__name__
|
716
|
+
),
|
717
|
+
api_calls=[Session.call],
|
718
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
719
|
+
)
|
720
|
+
if output_cols_prefix == "fit_predict_":
|
721
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
722
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
723
|
+
num_examples = self._sklearn_object.n_clusters
|
724
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
725
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
726
|
+
num_examples = self._sklearn_object.min_samples
|
727
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
728
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
729
|
+
num_examples = self._sklearn_object.n_neighbors
|
730
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
731
|
+
else:
|
732
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
696
733
|
|
697
734
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
698
735
|
# seen during the fit.
|
@@ -704,12 +741,14 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
704
741
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
705
742
|
if self.sample_weight_col:
|
706
743
|
output_df_columns_set -= set(self.sample_weight_col)
|
744
|
+
|
707
745
|
# if the dimension of inferred output column names is correct; use it
|
708
746
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
709
|
-
return expected_output_cols_list
|
747
|
+
return expected_output_cols_list, output_df_pd
|
710
748
|
# otherwise, use the sklearn estimator's output
|
711
749
|
else:
|
712
|
-
|
750
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
751
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
713
752
|
|
714
753
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
715
754
|
@telemetry.send_api_usage_telemetry(
|
@@ -757,7 +796,7 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
757
796
|
drop_input_cols=self._drop_input_cols,
|
758
797
|
expected_output_cols_type="float",
|
759
798
|
)
|
760
|
-
expected_output_cols = self.
|
799
|
+
expected_output_cols, _ = self._align_expected_output(
|
761
800
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
762
801
|
)
|
763
802
|
|
@@ -825,7 +864,7 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
825
864
|
drop_input_cols=self._drop_input_cols,
|
826
865
|
expected_output_cols_type="float",
|
827
866
|
)
|
828
|
-
expected_output_cols = self.
|
867
|
+
expected_output_cols, _ = self._align_expected_output(
|
829
868
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
830
869
|
)
|
831
870
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -888,7 +927,7 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
888
927
|
drop_input_cols=self._drop_input_cols,
|
889
928
|
expected_output_cols_type="float",
|
890
929
|
)
|
891
|
-
expected_output_cols = self.
|
930
|
+
expected_output_cols, _ = self._align_expected_output(
|
892
931
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
893
932
|
)
|
894
933
|
|
@@ -953,7 +992,7 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
953
992
|
drop_input_cols = self._drop_input_cols,
|
954
993
|
expected_output_cols_type="float",
|
955
994
|
)
|
956
|
-
expected_output_cols = self.
|
995
|
+
expected_output_cols, _ = self._align_expected_output(
|
957
996
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
958
997
|
)
|
959
998
|
|