snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +7 -33
- snowflake/ml/_internal/env_utils.py +11 -5
- snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
- snowflake/ml/_internal/telemetry.py +156 -20
- snowflake/ml/_internal/utils/identifier.py +48 -11
- snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
- snowflake/ml/_internal/utils/snowflake_env.py +23 -13
- snowflake/ml/_internal/utils/sql_identifier.py +1 -1
- snowflake/ml/_internal/utils/table_manager.py +19 -1
- snowflake/ml/_internal/utils/uri.py +2 -2
- snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
- snowflake/ml/data/data_connector.py +88 -9
- snowflake/ml/data/data_ingestor.py +18 -1
- snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
- snowflake/ml/data/torch_utils.py +68 -0
- snowflake/ml/dataset/dataset.py +1 -3
- snowflake/ml/dataset/dataset_metadata.py +3 -1
- snowflake/ml/dataset/dataset_reader.py +9 -3
- snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
- snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
- snowflake/ml/feature_store/examples/example_helper.py +69 -31
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
- snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
- snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
- snowflake/ml/feature_store/feature_store.py +100 -41
- snowflake/ml/feature_store/feature_view.py +149 -5
- snowflake/ml/fileset/embedded_stage_fs.py +1 -1
- snowflake/ml/fileset/fileset.py +1 -1
- snowflake/ml/fileset/sfcfs.py +9 -3
- snowflake/ml/model/_client/model/model_impl.py +11 -2
- snowflake/ml/model/_client/model/model_version_impl.py +186 -20
- snowflake/ml/model/_client/ops/model_ops.py +144 -30
- snowflake/ml/model/_client/ops/service_ops.py +312 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
- snowflake/ml/model/_client/sql/model_version.py +13 -4
- snowflake/ml/model/_client/sql/service.py +196 -0
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
- snowflake/ml/model/_model_composer/model_composer.py +5 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
- snowflake/ml/model/_packager/model_env/model_env.py +7 -2
- snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
- snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
- snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
- snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
- snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
- snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
- snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
- snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
- snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
- snowflake/ml/model/_packager/model_packager.py +4 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/_signatures/utils.py +9 -0
- snowflake/ml/model/models/llm.py +3 -1
- snowflake/ml/model/type_hints.py +10 -4
- snowflake/ml/modeling/_internal/constants.py +1 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
- snowflake/ml/modeling/_internal/model_specifications.py +2 -0
- snowflake/ml/modeling/_internal/model_trainer.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
- snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
- snowflake/ml/modeling/cluster/birch.py +60 -21
- snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
- snowflake/ml/modeling/cluster/dbscan.py +60 -21
- snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
- snowflake/ml/modeling/cluster/k_means.py +60 -21
- snowflake/ml/modeling/cluster/mean_shift.py +60 -21
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
- snowflake/ml/modeling/cluster/optics.py +60 -21
- snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
- snowflake/ml/modeling/compose/column_transformer.py +60 -21
- snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
- snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
- snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
- snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
- snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
- snowflake/ml/modeling/covariance/oas.py +60 -21
- snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
- snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
- snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
- snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
- snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/pca.py +60 -21
- snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
- snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
- snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
- snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
- snowflake/ml/modeling/framework/base.py +28 -19
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
- snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
- snowflake/ml/modeling/impute/knn_imputer.py +60 -21
- snowflake/ml/modeling/impute/missing_indicator.py +60 -21
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/lars.py +60 -21
- snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
- snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/perceptron.py +60 -21
- snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ridge.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
- snowflake/ml/modeling/manifold/isomap.py +60 -21
- snowflake/ml/modeling/manifold/mds.py +60 -21
- snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
- snowflake/ml/modeling/manifold/tsne.py +60 -21
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
- snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
- snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +4 -12
- snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
- snowflake/ml/modeling/svm/linear_svc.py +60 -21
- snowflake/ml/modeling/svm/linear_svr.py +60 -21
- snowflake/ml/modeling/svm/nu_svc.py +60 -21
- snowflake/ml/modeling/svm/nu_svr.py +60 -21
- snowflake/ml/modeling/svm/svc.py +60 -21
- snowflake/ml/modeling/svm/svr.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
- snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
- snowflake/ml/registry/_manager/model_manager.py +20 -2
- snowflake/ml/registry/model_registry.py +1 -1
- snowflake/ml/registry/registry.py +1 -2
- snowflake/ml/utils/sql_client.py +22 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/svm/svr.py
CHANGED
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -536,12 +533,23 @@ class SVR(BaseTransformer):
|
|
536
533
|
autogenerated=self._autogenerated,
|
537
534
|
subproject=_SUBPROJECT,
|
538
535
|
)
|
539
|
-
|
540
|
-
|
541
|
-
expected_output_cols_list=(
|
542
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
543
|
-
),
|
536
|
+
expected_output_cols = (
|
537
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
544
538
|
)
|
539
|
+
if isinstance(dataset, DataFrame):
|
540
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
541
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
542
|
+
)
|
543
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
544
|
+
drop_input_cols=self._drop_input_cols,
|
545
|
+
expected_output_cols_list=expected_output_cols,
|
546
|
+
example_output_pd_df=example_output_pd_df,
|
547
|
+
)
|
548
|
+
else:
|
549
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
550
|
+
drop_input_cols=self._drop_input_cols,
|
551
|
+
expected_output_cols_list=expected_output_cols,
|
552
|
+
)
|
545
553
|
self._sklearn_object = fitted_estimator
|
546
554
|
self._is_fitted = True
|
547
555
|
return output_result
|
@@ -620,12 +628,41 @@ class SVR(BaseTransformer):
|
|
620
628
|
|
621
629
|
return rv
|
622
630
|
|
623
|
-
def
|
624
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
625
|
-
) -> List[str]:
|
631
|
+
def _align_expected_output(
|
632
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
633
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
634
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
635
|
+
and output dataframe with 1 line.
|
636
|
+
If the method is fit_predict, run 2 lines of data.
|
637
|
+
"""
|
626
638
|
# in case the inferred output column names dimension is different
|
627
639
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
628
|
-
|
640
|
+
|
641
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
642
|
+
# so change the minimum of number of rows to 2
|
643
|
+
num_examples = 2
|
644
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
645
|
+
project=_PROJECT,
|
646
|
+
subproject=_SUBPROJECT,
|
647
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
648
|
+
inspect.currentframe(), SVR.__class__.__name__
|
649
|
+
),
|
650
|
+
api_calls=[Session.call],
|
651
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
652
|
+
)
|
653
|
+
if output_cols_prefix == "fit_predict_":
|
654
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
655
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
656
|
+
num_examples = self._sklearn_object.n_clusters
|
657
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
658
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
659
|
+
num_examples = self._sklearn_object.min_samples
|
660
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
661
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
662
|
+
num_examples = self._sklearn_object.n_neighbors
|
663
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
664
|
+
else:
|
665
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
629
666
|
|
630
667
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
631
668
|
# seen during the fit.
|
@@ -637,12 +674,14 @@ class SVR(BaseTransformer):
|
|
637
674
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
638
675
|
if self.sample_weight_col:
|
639
676
|
output_df_columns_set -= set(self.sample_weight_col)
|
677
|
+
|
640
678
|
# if the dimension of inferred output column names is correct; use it
|
641
679
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
642
|
-
return expected_output_cols_list
|
680
|
+
return expected_output_cols_list, output_df_pd
|
643
681
|
# otherwise, use the sklearn estimator's output
|
644
682
|
else:
|
645
|
-
|
683
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
684
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
646
685
|
|
647
686
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
648
687
|
@telemetry.send_api_usage_telemetry(
|
@@ -688,7 +727,7 @@ class SVR(BaseTransformer):
|
|
688
727
|
drop_input_cols=self._drop_input_cols,
|
689
728
|
expected_output_cols_type="float",
|
690
729
|
)
|
691
|
-
expected_output_cols = self.
|
730
|
+
expected_output_cols, _ = self._align_expected_output(
|
692
731
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
693
732
|
)
|
694
733
|
|
@@ -754,7 +793,7 @@ class SVR(BaseTransformer):
|
|
754
793
|
drop_input_cols=self._drop_input_cols,
|
755
794
|
expected_output_cols_type="float",
|
756
795
|
)
|
757
|
-
expected_output_cols = self.
|
796
|
+
expected_output_cols, _ = self._align_expected_output(
|
758
797
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
759
798
|
)
|
760
799
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -817,7 +856,7 @@ class SVR(BaseTransformer):
|
|
817
856
|
drop_input_cols=self._drop_input_cols,
|
818
857
|
expected_output_cols_type="float",
|
819
858
|
)
|
820
|
-
expected_output_cols = self.
|
859
|
+
expected_output_cols, _ = self._align_expected_output(
|
821
860
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
822
861
|
)
|
823
862
|
|
@@ -882,7 +921,7 @@ class SVR(BaseTransformer):
|
|
882
921
|
drop_input_cols = self._drop_input_cols,
|
883
922
|
expected_output_cols_type="float",
|
884
923
|
)
|
885
|
-
expected_output_cols = self.
|
924
|
+
expected_output_cols, _ = self._align_expected_output(
|
886
925
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
887
926
|
)
|
888
927
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -603,12 +600,23 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
603
600
|
autogenerated=self._autogenerated,
|
604
601
|
subproject=_SUBPROJECT,
|
605
602
|
)
|
606
|
-
|
607
|
-
|
608
|
-
expected_output_cols_list=(
|
609
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
610
|
-
),
|
603
|
+
expected_output_cols = (
|
604
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
611
605
|
)
|
606
|
+
if isinstance(dataset, DataFrame):
|
607
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
608
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
609
|
+
)
|
610
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
611
|
+
drop_input_cols=self._drop_input_cols,
|
612
|
+
expected_output_cols_list=expected_output_cols,
|
613
|
+
example_output_pd_df=example_output_pd_df,
|
614
|
+
)
|
615
|
+
else:
|
616
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
617
|
+
drop_input_cols=self._drop_input_cols,
|
618
|
+
expected_output_cols_list=expected_output_cols,
|
619
|
+
)
|
612
620
|
self._sklearn_object = fitted_estimator
|
613
621
|
self._is_fitted = True
|
614
622
|
return output_result
|
@@ -687,12 +695,41 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
687
695
|
|
688
696
|
return rv
|
689
697
|
|
690
|
-
def
|
691
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
692
|
-
) -> List[str]:
|
698
|
+
def _align_expected_output(
|
699
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
700
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
701
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
702
|
+
and output dataframe with 1 line.
|
703
|
+
If the method is fit_predict, run 2 lines of data.
|
704
|
+
"""
|
693
705
|
# in case the inferred output column names dimension is different
|
694
706
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
695
|
-
|
707
|
+
|
708
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
709
|
+
# so change the minimum of number of rows to 2
|
710
|
+
num_examples = 2
|
711
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
712
|
+
project=_PROJECT,
|
713
|
+
subproject=_SUBPROJECT,
|
714
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
715
|
+
inspect.currentframe(), DecisionTreeClassifier.__class__.__name__
|
716
|
+
),
|
717
|
+
api_calls=[Session.call],
|
718
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
719
|
+
)
|
720
|
+
if output_cols_prefix == "fit_predict_":
|
721
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
722
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
723
|
+
num_examples = self._sklearn_object.n_clusters
|
724
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
725
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
726
|
+
num_examples = self._sklearn_object.min_samples
|
727
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
728
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
729
|
+
num_examples = self._sklearn_object.n_neighbors
|
730
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
731
|
+
else:
|
732
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
696
733
|
|
697
734
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
698
735
|
# seen during the fit.
|
@@ -704,12 +741,14 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
704
741
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
705
742
|
if self.sample_weight_col:
|
706
743
|
output_df_columns_set -= set(self.sample_weight_col)
|
744
|
+
|
707
745
|
# if the dimension of inferred output column names is correct; use it
|
708
746
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
709
|
-
return expected_output_cols_list
|
747
|
+
return expected_output_cols_list, output_df_pd
|
710
748
|
# otherwise, use the sklearn estimator's output
|
711
749
|
else:
|
712
|
-
|
750
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
751
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
713
752
|
|
714
753
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
715
754
|
@telemetry.send_api_usage_telemetry(
|
@@ -757,7 +796,7 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
757
796
|
drop_input_cols=self._drop_input_cols,
|
758
797
|
expected_output_cols_type="float",
|
759
798
|
)
|
760
|
-
expected_output_cols = self.
|
799
|
+
expected_output_cols, _ = self._align_expected_output(
|
761
800
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
762
801
|
)
|
763
802
|
|
@@ -825,7 +864,7 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
825
864
|
drop_input_cols=self._drop_input_cols,
|
826
865
|
expected_output_cols_type="float",
|
827
866
|
)
|
828
|
-
expected_output_cols = self.
|
867
|
+
expected_output_cols, _ = self._align_expected_output(
|
829
868
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
830
869
|
)
|
831
870
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -888,7 +927,7 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
888
927
|
drop_input_cols=self._drop_input_cols,
|
889
928
|
expected_output_cols_type="float",
|
890
929
|
)
|
891
|
-
expected_output_cols = self.
|
930
|
+
expected_output_cols, _ = self._align_expected_output(
|
892
931
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
893
932
|
)
|
894
933
|
|
@@ -953,7 +992,7 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
953
992
|
drop_input_cols = self._drop_input_cols,
|
954
993
|
expected_output_cols_type="float",
|
955
994
|
)
|
956
|
-
expected_output_cols = self.
|
995
|
+
expected_output_cols, _ = self._align_expected_output(
|
957
996
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
958
997
|
)
|
959
998
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -585,12 +582,23 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
585
582
|
autogenerated=self._autogenerated,
|
586
583
|
subproject=_SUBPROJECT,
|
587
584
|
)
|
588
|
-
|
589
|
-
|
590
|
-
expected_output_cols_list=(
|
591
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
592
|
-
),
|
585
|
+
expected_output_cols = (
|
586
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
593
587
|
)
|
588
|
+
if isinstance(dataset, DataFrame):
|
589
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
590
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
591
|
+
)
|
592
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
593
|
+
drop_input_cols=self._drop_input_cols,
|
594
|
+
expected_output_cols_list=expected_output_cols,
|
595
|
+
example_output_pd_df=example_output_pd_df,
|
596
|
+
)
|
597
|
+
else:
|
598
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
599
|
+
drop_input_cols=self._drop_input_cols,
|
600
|
+
expected_output_cols_list=expected_output_cols,
|
601
|
+
)
|
594
602
|
self._sklearn_object = fitted_estimator
|
595
603
|
self._is_fitted = True
|
596
604
|
return output_result
|
@@ -669,12 +677,41 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
669
677
|
|
670
678
|
return rv
|
671
679
|
|
672
|
-
def
|
673
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
674
|
-
) -> List[str]:
|
680
|
+
def _align_expected_output(
|
681
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
682
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
683
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
684
|
+
and output dataframe with 1 line.
|
685
|
+
If the method is fit_predict, run 2 lines of data.
|
686
|
+
"""
|
675
687
|
# in case the inferred output column names dimension is different
|
676
688
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
677
|
-
|
689
|
+
|
690
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
691
|
+
# so change the minimum of number of rows to 2
|
692
|
+
num_examples = 2
|
693
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
694
|
+
project=_PROJECT,
|
695
|
+
subproject=_SUBPROJECT,
|
696
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
697
|
+
inspect.currentframe(), DecisionTreeRegressor.__class__.__name__
|
698
|
+
),
|
699
|
+
api_calls=[Session.call],
|
700
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
701
|
+
)
|
702
|
+
if output_cols_prefix == "fit_predict_":
|
703
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
704
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
705
|
+
num_examples = self._sklearn_object.n_clusters
|
706
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
707
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
708
|
+
num_examples = self._sklearn_object.min_samples
|
709
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
710
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
711
|
+
num_examples = self._sklearn_object.n_neighbors
|
712
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
713
|
+
else:
|
714
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
678
715
|
|
679
716
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
680
717
|
# seen during the fit.
|
@@ -686,12 +723,14 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
686
723
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
687
724
|
if self.sample_weight_col:
|
688
725
|
output_df_columns_set -= set(self.sample_weight_col)
|
726
|
+
|
689
727
|
# if the dimension of inferred output column names is correct; use it
|
690
728
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
691
|
-
return expected_output_cols_list
|
729
|
+
return expected_output_cols_list, output_df_pd
|
692
730
|
# otherwise, use the sklearn estimator's output
|
693
731
|
else:
|
694
|
-
|
732
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
733
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
695
734
|
|
696
735
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
697
736
|
@telemetry.send_api_usage_telemetry(
|
@@ -737,7 +776,7 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
737
776
|
drop_input_cols=self._drop_input_cols,
|
738
777
|
expected_output_cols_type="float",
|
739
778
|
)
|
740
|
-
expected_output_cols = self.
|
779
|
+
expected_output_cols, _ = self._align_expected_output(
|
741
780
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
742
781
|
)
|
743
782
|
|
@@ -803,7 +842,7 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
803
842
|
drop_input_cols=self._drop_input_cols,
|
804
843
|
expected_output_cols_type="float",
|
805
844
|
)
|
806
|
-
expected_output_cols = self.
|
845
|
+
expected_output_cols, _ = self._align_expected_output(
|
807
846
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
808
847
|
)
|
809
848
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -866,7 +905,7 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
866
905
|
drop_input_cols=self._drop_input_cols,
|
867
906
|
expected_output_cols_type="float",
|
868
907
|
)
|
869
|
-
expected_output_cols = self.
|
908
|
+
expected_output_cols, _ = self._align_expected_output(
|
870
909
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
871
910
|
)
|
872
911
|
|
@@ -931,7 +970,7 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
931
970
|
drop_input_cols = self._drop_input_cols,
|
932
971
|
expected_output_cols_type="float",
|
933
972
|
)
|
934
|
-
expected_output_cols = self.
|
973
|
+
expected_output_cols, _ = self._align_expected_output(
|
935
974
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
936
975
|
)
|
937
976
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -595,12 +592,23 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
595
592
|
autogenerated=self._autogenerated,
|
596
593
|
subproject=_SUBPROJECT,
|
597
594
|
)
|
598
|
-
|
599
|
-
|
600
|
-
expected_output_cols_list=(
|
601
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
602
|
-
),
|
595
|
+
expected_output_cols = (
|
596
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
603
597
|
)
|
598
|
+
if isinstance(dataset, DataFrame):
|
599
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
600
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
601
|
+
)
|
602
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
603
|
+
drop_input_cols=self._drop_input_cols,
|
604
|
+
expected_output_cols_list=expected_output_cols,
|
605
|
+
example_output_pd_df=example_output_pd_df,
|
606
|
+
)
|
607
|
+
else:
|
608
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
609
|
+
drop_input_cols=self._drop_input_cols,
|
610
|
+
expected_output_cols_list=expected_output_cols,
|
611
|
+
)
|
604
612
|
self._sklearn_object = fitted_estimator
|
605
613
|
self._is_fitted = True
|
606
614
|
return output_result
|
@@ -679,12 +687,41 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
679
687
|
|
680
688
|
return rv
|
681
689
|
|
682
|
-
def
|
683
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
684
|
-
) -> List[str]:
|
690
|
+
def _align_expected_output(
|
691
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
692
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
693
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
694
|
+
and output dataframe with 1 line.
|
695
|
+
If the method is fit_predict, run 2 lines of data.
|
696
|
+
"""
|
685
697
|
# in case the inferred output column names dimension is different
|
686
698
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
687
|
-
|
699
|
+
|
700
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
701
|
+
# so change the minimum of number of rows to 2
|
702
|
+
num_examples = 2
|
703
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
704
|
+
project=_PROJECT,
|
705
|
+
subproject=_SUBPROJECT,
|
706
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
707
|
+
inspect.currentframe(), ExtraTreeClassifier.__class__.__name__
|
708
|
+
),
|
709
|
+
api_calls=[Session.call],
|
710
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
711
|
+
)
|
712
|
+
if output_cols_prefix == "fit_predict_":
|
713
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
714
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
715
|
+
num_examples = self._sklearn_object.n_clusters
|
716
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
717
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
718
|
+
num_examples = self._sklearn_object.min_samples
|
719
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
720
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
721
|
+
num_examples = self._sklearn_object.n_neighbors
|
722
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
723
|
+
else:
|
724
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
688
725
|
|
689
726
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
690
727
|
# seen during the fit.
|
@@ -696,12 +733,14 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
696
733
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
697
734
|
if self.sample_weight_col:
|
698
735
|
output_df_columns_set -= set(self.sample_weight_col)
|
736
|
+
|
699
737
|
# if the dimension of inferred output column names is correct; use it
|
700
738
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
701
|
-
return expected_output_cols_list
|
739
|
+
return expected_output_cols_list, output_df_pd
|
702
740
|
# otherwise, use the sklearn estimator's output
|
703
741
|
else:
|
704
|
-
|
742
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
743
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
705
744
|
|
706
745
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
707
746
|
@telemetry.send_api_usage_telemetry(
|
@@ -749,7 +788,7 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
749
788
|
drop_input_cols=self._drop_input_cols,
|
750
789
|
expected_output_cols_type="float",
|
751
790
|
)
|
752
|
-
expected_output_cols = self.
|
791
|
+
expected_output_cols, _ = self._align_expected_output(
|
753
792
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
754
793
|
)
|
755
794
|
|
@@ -817,7 +856,7 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
817
856
|
drop_input_cols=self._drop_input_cols,
|
818
857
|
expected_output_cols_type="float",
|
819
858
|
)
|
820
|
-
expected_output_cols = self.
|
859
|
+
expected_output_cols, _ = self._align_expected_output(
|
821
860
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
822
861
|
)
|
823
862
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -880,7 +919,7 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
880
919
|
drop_input_cols=self._drop_input_cols,
|
881
920
|
expected_output_cols_type="float",
|
882
921
|
)
|
883
|
-
expected_output_cols = self.
|
922
|
+
expected_output_cols, _ = self._align_expected_output(
|
884
923
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
885
924
|
)
|
886
925
|
|
@@ -945,7 +984,7 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
945
984
|
drop_input_cols = self._drop_input_cols,
|
946
985
|
expected_output_cols_type="float",
|
947
986
|
)
|
948
|
-
expected_output_cols = self.
|
987
|
+
expected_output_cols, _ = self._align_expected_output(
|
949
988
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
950
989
|
)
|
951
990
|
|