snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -0
- snowflake/cortex/_classify_text.py +2 -2
- snowflake/cortex/_embed_text_1024.py +37 -0
- snowflake/cortex/_embed_text_768.py +37 -0
- snowflake/cortex/_extract_answer.py +2 -2
- snowflake/cortex/_sentiment.py +2 -2
- snowflake/cortex/_summarize.py +2 -2
- snowflake/cortex/_translate.py +2 -2
- snowflake/cortex/_util.py +4 -4
- snowflake/ml/_internal/env_utils.py +5 -5
- snowflake/ml/_internal/exceptions/error_codes.py +2 -0
- snowflake/ml/_internal/telemetry.py +142 -20
- snowflake/ml/_internal/utils/db_utils.py +50 -0
- snowflake/ml/_internal/utils/identifier.py +48 -11
- snowflake/ml/_internal/utils/service_logger.py +63 -0
- snowflake/ml/_internal/utils/snowflake_env.py +23 -13
- snowflake/ml/_internal/utils/sql_identifier.py +26 -2
- snowflake/ml/_internal/utils/table_manager.py +19 -1
- snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
- snowflake/ml/data/data_connector.py +33 -7
- snowflake/ml/data/ingestor_utils.py +20 -10
- snowflake/ml/data/torch_utils.py +68 -0
- snowflake/ml/dataset/dataset.py +1 -3
- snowflake/ml/feature_store/access_manager.py +3 -3
- snowflake/ml/feature_store/feature_store.py +60 -19
- snowflake/ml/feature_store/feature_view.py +84 -30
- snowflake/ml/fileset/embedded_stage_fs.py +1 -1
- snowflake/ml/fileset/fileset.py +1 -1
- snowflake/ml/fileset/sfcfs.py +9 -3
- snowflake/ml/fileset/stage_fs.py +2 -1
- snowflake/ml/lineage/lineage_node.py +7 -2
- snowflake/ml/model/__init__.py +1 -2
- snowflake/ml/model/_client/model/model_version_impl.py +96 -12
- snowflake/ml/model/_client/ops/model_ops.py +124 -6
- snowflake/ml/model/_client/ops/service_ops.py +309 -9
- snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
- snowflake/ml/model/_client/sql/_base.py +5 -0
- snowflake/ml/model/_client/sql/model.py +1 -0
- snowflake/ml/model/_client/sql/model_version.py +9 -5
- snowflake/ml/model/_client/sql/service.py +121 -20
- snowflake/ml/model/_model_composer/model_composer.py +11 -39
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
- snowflake/ml/model/_packager/model_env/model_env.py +4 -38
- snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
- snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
- snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
- snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
- snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
- snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
- snowflake/ml/model/_packager/model_packager.py +14 -8
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
- snowflake/ml/model/_signatures/utils.py +9 -0
- snowflake/ml/model/type_hints.py +12 -145
- snowflake/ml/modeling/_internal/constants.py +1 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
- snowflake/ml/modeling/_internal/model_specifications.py +2 -0
- snowflake/ml/modeling/_internal/model_trainer.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
- snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
- snowflake/ml/modeling/cluster/birch.py +61 -21
- snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
- snowflake/ml/modeling/cluster/dbscan.py +61 -21
- snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
- snowflake/ml/modeling/cluster/k_means.py +61 -21
- snowflake/ml/modeling/cluster/mean_shift.py +61 -21
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
- snowflake/ml/modeling/cluster/optics.py +61 -21
- snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
- snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
- snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
- snowflake/ml/modeling/compose/column_transformer.py +61 -21
- snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
- snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
- snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
- snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
- snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
- snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
- snowflake/ml/modeling/covariance/oas.py +61 -21
- snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
- snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
- snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
- snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
- snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
- snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
- snowflake/ml/modeling/decomposition/pca.py +61 -21
- snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
- snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
- snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
- snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
- snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
- snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
- snowflake/ml/modeling/impute/knn_imputer.py +61 -21
- snowflake/ml/modeling/impute/missing_indicator.py +61 -21
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
- snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
- snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
- snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/lars.py +61 -21
- snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
- snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
- snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/perceptron.py +61 -21
- snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ridge.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
- snowflake/ml/modeling/manifold/isomap.py +61 -21
- snowflake/ml/modeling/manifold/mds.py +61 -21
- snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
- snowflake/ml/modeling/manifold/tsne.py +61 -21
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
- snowflake/ml/modeling/metrics/ranking.py +0 -3
- snowflake/ml/modeling/metrics/regression.py +0 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
- snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
- snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
- snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
- snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
- snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
- snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
- snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +1 -13
- snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
- snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
- snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
- snowflake/ml/modeling/svm/linear_svc.py +61 -21
- snowflake/ml/modeling/svm/linear_svr.py +61 -21
- snowflake/ml/modeling/svm/nu_svc.py +61 -21
- snowflake/ml/modeling/svm/nu_svr.py +61 -21
- snowflake/ml/modeling/svm/svc.py +61 -21
- snowflake/ml/modeling/svm/svr.py +61 -21
- snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
- snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
- snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
- snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
- snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
- snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
- snowflake/ml/monitoring/_client/model_monitor.py +126 -0
- snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
- snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
- snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
- snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
- snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
- snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
- snowflake/ml/monitoring/entities/output_score_type.py +90 -0
- snowflake/ml/registry/_manager/model_manager.py +4 -0
- snowflake/ml/registry/registry.py +166 -8
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
- snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
- snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
- snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
- snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
- snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
- snowflake/ml/_internal/utils/session_token_manager.py +0 -46
- snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
- snowflake/ml/_internal/utils/uri.py +0 -77
- snowflake/ml/data/torch_dataset.py +0 -33
- snowflake/ml/model/_api.py +0 -568
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
- snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
- snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
- snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
- snowflake/ml/model/deploy_platforms.py +0 -6
- snowflake/ml/model/models/llm.py +0 -104
- snowflake/ml/monitoring/monitor.py +0 -203
- snowflake/ml/registry/_initial_schema.py +0 -142
- snowflake/ml/registry/_schema.py +0 -82
- snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
- snowflake/ml/registry/_schema_version_manager.py +0 -163
- snowflake/ml/registry/model_registry.py +0 -2048
- snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/svm/svr.py
CHANGED
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -536,12 +533,23 @@ class SVR(BaseTransformer):
|
|
536
533
|
autogenerated=self._autogenerated,
|
537
534
|
subproject=_SUBPROJECT,
|
538
535
|
)
|
539
|
-
|
540
|
-
|
541
|
-
expected_output_cols_list=(
|
542
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
543
|
-
),
|
536
|
+
expected_output_cols = (
|
537
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
544
538
|
)
|
539
|
+
if isinstance(dataset, DataFrame):
|
540
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
541
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
542
|
+
)
|
543
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
544
|
+
drop_input_cols=self._drop_input_cols,
|
545
|
+
expected_output_cols_list=expected_output_cols,
|
546
|
+
example_output_pd_df=example_output_pd_df,
|
547
|
+
)
|
548
|
+
else:
|
549
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
550
|
+
drop_input_cols=self._drop_input_cols,
|
551
|
+
expected_output_cols_list=expected_output_cols,
|
552
|
+
)
|
545
553
|
self._sklearn_object = fitted_estimator
|
546
554
|
self._is_fitted = True
|
547
555
|
return output_result
|
@@ -564,6 +572,7 @@ class SVR(BaseTransformer):
|
|
564
572
|
"""
|
565
573
|
self._infer_input_output_cols(dataset)
|
566
574
|
super()._check_dataset_type(dataset)
|
575
|
+
|
567
576
|
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
568
577
|
estimator=self._sklearn_object,
|
569
578
|
dataset=dataset,
|
@@ -620,12 +629,41 @@ class SVR(BaseTransformer):
|
|
620
629
|
|
621
630
|
return rv
|
622
631
|
|
623
|
-
def
|
624
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
625
|
-
) -> List[str]:
|
632
|
+
def _align_expected_output(
|
633
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
634
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
635
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
636
|
+
and output dataframe with 1 line.
|
637
|
+
If the method is fit_predict, run 2 lines of data.
|
638
|
+
"""
|
626
639
|
# in case the inferred output column names dimension is different
|
627
640
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
628
|
-
|
641
|
+
|
642
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
643
|
+
# so change the minimum of number of rows to 2
|
644
|
+
num_examples = 2
|
645
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
646
|
+
project=_PROJECT,
|
647
|
+
subproject=_SUBPROJECT,
|
648
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
649
|
+
inspect.currentframe(), SVR.__class__.__name__
|
650
|
+
),
|
651
|
+
api_calls=[Session.call],
|
652
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
653
|
+
)
|
654
|
+
if output_cols_prefix == "fit_predict_":
|
655
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
656
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
657
|
+
num_examples = self._sklearn_object.n_clusters
|
658
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
659
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
660
|
+
num_examples = self._sklearn_object.min_samples
|
661
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
662
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
663
|
+
num_examples = self._sklearn_object.n_neighbors
|
664
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
665
|
+
else:
|
666
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
629
667
|
|
630
668
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
631
669
|
# seen during the fit.
|
@@ -637,12 +675,14 @@ class SVR(BaseTransformer):
|
|
637
675
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
638
676
|
if self.sample_weight_col:
|
639
677
|
output_df_columns_set -= set(self.sample_weight_col)
|
678
|
+
|
640
679
|
# if the dimension of inferred output column names is correct; use it
|
641
680
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
642
|
-
return expected_output_cols_list
|
681
|
+
return expected_output_cols_list, output_df_pd
|
643
682
|
# otherwise, use the sklearn estimator's output
|
644
683
|
else:
|
645
|
-
|
684
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
685
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
646
686
|
|
647
687
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
648
688
|
@telemetry.send_api_usage_telemetry(
|
@@ -688,7 +728,7 @@ class SVR(BaseTransformer):
|
|
688
728
|
drop_input_cols=self._drop_input_cols,
|
689
729
|
expected_output_cols_type="float",
|
690
730
|
)
|
691
|
-
expected_output_cols = self.
|
731
|
+
expected_output_cols, _ = self._align_expected_output(
|
692
732
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
693
733
|
)
|
694
734
|
|
@@ -754,7 +794,7 @@ class SVR(BaseTransformer):
|
|
754
794
|
drop_input_cols=self._drop_input_cols,
|
755
795
|
expected_output_cols_type="float",
|
756
796
|
)
|
757
|
-
expected_output_cols = self.
|
797
|
+
expected_output_cols, _ = self._align_expected_output(
|
758
798
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
759
799
|
)
|
760
800
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -817,7 +857,7 @@ class SVR(BaseTransformer):
|
|
817
857
|
drop_input_cols=self._drop_input_cols,
|
818
858
|
expected_output_cols_type="float",
|
819
859
|
)
|
820
|
-
expected_output_cols = self.
|
860
|
+
expected_output_cols, _ = self._align_expected_output(
|
821
861
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
822
862
|
)
|
823
863
|
|
@@ -882,7 +922,7 @@ class SVR(BaseTransformer):
|
|
882
922
|
drop_input_cols = self._drop_input_cols,
|
883
923
|
expected_output_cols_type="float",
|
884
924
|
)
|
885
|
-
expected_output_cols = self.
|
925
|
+
expected_output_cols, _ = self._align_expected_output(
|
886
926
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
887
927
|
)
|
888
928
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -603,12 +600,23 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
603
600
|
autogenerated=self._autogenerated,
|
604
601
|
subproject=_SUBPROJECT,
|
605
602
|
)
|
606
|
-
|
607
|
-
|
608
|
-
expected_output_cols_list=(
|
609
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
610
|
-
),
|
603
|
+
expected_output_cols = (
|
604
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
611
605
|
)
|
606
|
+
if isinstance(dataset, DataFrame):
|
607
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
608
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
609
|
+
)
|
610
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
611
|
+
drop_input_cols=self._drop_input_cols,
|
612
|
+
expected_output_cols_list=expected_output_cols,
|
613
|
+
example_output_pd_df=example_output_pd_df,
|
614
|
+
)
|
615
|
+
else:
|
616
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
617
|
+
drop_input_cols=self._drop_input_cols,
|
618
|
+
expected_output_cols_list=expected_output_cols,
|
619
|
+
)
|
612
620
|
self._sklearn_object = fitted_estimator
|
613
621
|
self._is_fitted = True
|
614
622
|
return output_result
|
@@ -631,6 +639,7 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
631
639
|
"""
|
632
640
|
self._infer_input_output_cols(dataset)
|
633
641
|
super()._check_dataset_type(dataset)
|
642
|
+
|
634
643
|
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
635
644
|
estimator=self._sklearn_object,
|
636
645
|
dataset=dataset,
|
@@ -687,12 +696,41 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
687
696
|
|
688
697
|
return rv
|
689
698
|
|
690
|
-
def
|
691
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
692
|
-
) -> List[str]:
|
699
|
+
def _align_expected_output(
|
700
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
701
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
702
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
703
|
+
and output dataframe with 1 line.
|
704
|
+
If the method is fit_predict, run 2 lines of data.
|
705
|
+
"""
|
693
706
|
# in case the inferred output column names dimension is different
|
694
707
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
695
|
-
|
708
|
+
|
709
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
710
|
+
# so change the minimum of number of rows to 2
|
711
|
+
num_examples = 2
|
712
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
713
|
+
project=_PROJECT,
|
714
|
+
subproject=_SUBPROJECT,
|
715
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
716
|
+
inspect.currentframe(), DecisionTreeClassifier.__class__.__name__
|
717
|
+
),
|
718
|
+
api_calls=[Session.call],
|
719
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
720
|
+
)
|
721
|
+
if output_cols_prefix == "fit_predict_":
|
722
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
723
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
724
|
+
num_examples = self._sklearn_object.n_clusters
|
725
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
726
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
727
|
+
num_examples = self._sklearn_object.min_samples
|
728
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
729
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
730
|
+
num_examples = self._sklearn_object.n_neighbors
|
731
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
732
|
+
else:
|
733
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
696
734
|
|
697
735
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
698
736
|
# seen during the fit.
|
@@ -704,12 +742,14 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
704
742
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
705
743
|
if self.sample_weight_col:
|
706
744
|
output_df_columns_set -= set(self.sample_weight_col)
|
745
|
+
|
707
746
|
# if the dimension of inferred output column names is correct; use it
|
708
747
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
709
|
-
return expected_output_cols_list
|
748
|
+
return expected_output_cols_list, output_df_pd
|
710
749
|
# otherwise, use the sklearn estimator's output
|
711
750
|
else:
|
712
|
-
|
751
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
752
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
713
753
|
|
714
754
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
715
755
|
@telemetry.send_api_usage_telemetry(
|
@@ -757,7 +797,7 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
757
797
|
drop_input_cols=self._drop_input_cols,
|
758
798
|
expected_output_cols_type="float",
|
759
799
|
)
|
760
|
-
expected_output_cols = self.
|
800
|
+
expected_output_cols, _ = self._align_expected_output(
|
761
801
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
762
802
|
)
|
763
803
|
|
@@ -825,7 +865,7 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
825
865
|
drop_input_cols=self._drop_input_cols,
|
826
866
|
expected_output_cols_type="float",
|
827
867
|
)
|
828
|
-
expected_output_cols = self.
|
868
|
+
expected_output_cols, _ = self._align_expected_output(
|
829
869
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
830
870
|
)
|
831
871
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -888,7 +928,7 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
888
928
|
drop_input_cols=self._drop_input_cols,
|
889
929
|
expected_output_cols_type="float",
|
890
930
|
)
|
891
|
-
expected_output_cols = self.
|
931
|
+
expected_output_cols, _ = self._align_expected_output(
|
892
932
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
893
933
|
)
|
894
934
|
|
@@ -953,7 +993,7 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
953
993
|
drop_input_cols = self._drop_input_cols,
|
954
994
|
expected_output_cols_type="float",
|
955
995
|
)
|
956
|
-
expected_output_cols = self.
|
996
|
+
expected_output_cols, _ = self._align_expected_output(
|
957
997
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
958
998
|
)
|
959
999
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -585,12 +582,23 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
585
582
|
autogenerated=self._autogenerated,
|
586
583
|
subproject=_SUBPROJECT,
|
587
584
|
)
|
588
|
-
|
589
|
-
|
590
|
-
expected_output_cols_list=(
|
591
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
592
|
-
),
|
585
|
+
expected_output_cols = (
|
586
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
593
587
|
)
|
588
|
+
if isinstance(dataset, DataFrame):
|
589
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
590
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
591
|
+
)
|
592
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
593
|
+
drop_input_cols=self._drop_input_cols,
|
594
|
+
expected_output_cols_list=expected_output_cols,
|
595
|
+
example_output_pd_df=example_output_pd_df,
|
596
|
+
)
|
597
|
+
else:
|
598
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
599
|
+
drop_input_cols=self._drop_input_cols,
|
600
|
+
expected_output_cols_list=expected_output_cols,
|
601
|
+
)
|
594
602
|
self._sklearn_object = fitted_estimator
|
595
603
|
self._is_fitted = True
|
596
604
|
return output_result
|
@@ -613,6 +621,7 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
613
621
|
"""
|
614
622
|
self._infer_input_output_cols(dataset)
|
615
623
|
super()._check_dataset_type(dataset)
|
624
|
+
|
616
625
|
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
617
626
|
estimator=self._sklearn_object,
|
618
627
|
dataset=dataset,
|
@@ -669,12 +678,41 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
669
678
|
|
670
679
|
return rv
|
671
680
|
|
672
|
-
def
|
673
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
674
|
-
) -> List[str]:
|
681
|
+
def _align_expected_output(
|
682
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
683
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
684
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
685
|
+
and output dataframe with 1 line.
|
686
|
+
If the method is fit_predict, run 2 lines of data.
|
687
|
+
"""
|
675
688
|
# in case the inferred output column names dimension is different
|
676
689
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
677
|
-
|
690
|
+
|
691
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
692
|
+
# so change the minimum of number of rows to 2
|
693
|
+
num_examples = 2
|
694
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
695
|
+
project=_PROJECT,
|
696
|
+
subproject=_SUBPROJECT,
|
697
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
698
|
+
inspect.currentframe(), DecisionTreeRegressor.__class__.__name__
|
699
|
+
),
|
700
|
+
api_calls=[Session.call],
|
701
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
702
|
+
)
|
703
|
+
if output_cols_prefix == "fit_predict_":
|
704
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
705
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
706
|
+
num_examples = self._sklearn_object.n_clusters
|
707
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
708
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
709
|
+
num_examples = self._sklearn_object.min_samples
|
710
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
711
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
712
|
+
num_examples = self._sklearn_object.n_neighbors
|
713
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
714
|
+
else:
|
715
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
678
716
|
|
679
717
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
680
718
|
# seen during the fit.
|
@@ -686,12 +724,14 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
686
724
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
687
725
|
if self.sample_weight_col:
|
688
726
|
output_df_columns_set -= set(self.sample_weight_col)
|
727
|
+
|
689
728
|
# if the dimension of inferred output column names is correct; use it
|
690
729
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
691
|
-
return expected_output_cols_list
|
730
|
+
return expected_output_cols_list, output_df_pd
|
692
731
|
# otherwise, use the sklearn estimator's output
|
693
732
|
else:
|
694
|
-
|
733
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
734
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
695
735
|
|
696
736
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
697
737
|
@telemetry.send_api_usage_telemetry(
|
@@ -737,7 +777,7 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
737
777
|
drop_input_cols=self._drop_input_cols,
|
738
778
|
expected_output_cols_type="float",
|
739
779
|
)
|
740
|
-
expected_output_cols = self.
|
780
|
+
expected_output_cols, _ = self._align_expected_output(
|
741
781
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
742
782
|
)
|
743
783
|
|
@@ -803,7 +843,7 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
803
843
|
drop_input_cols=self._drop_input_cols,
|
804
844
|
expected_output_cols_type="float",
|
805
845
|
)
|
806
|
-
expected_output_cols = self.
|
846
|
+
expected_output_cols, _ = self._align_expected_output(
|
807
847
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
808
848
|
)
|
809
849
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -866,7 +906,7 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
866
906
|
drop_input_cols=self._drop_input_cols,
|
867
907
|
expected_output_cols_type="float",
|
868
908
|
)
|
869
|
-
expected_output_cols = self.
|
909
|
+
expected_output_cols, _ = self._align_expected_output(
|
870
910
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
871
911
|
)
|
872
912
|
|
@@ -931,7 +971,7 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
931
971
|
drop_input_cols = self._drop_input_cols,
|
932
972
|
expected_output_cols_type="float",
|
933
973
|
)
|
934
|
-
expected_output_cols = self.
|
974
|
+
expected_output_cols, _ = self._align_expected_output(
|
935
975
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
936
976
|
)
|
937
977
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -595,12 +592,23 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
595
592
|
autogenerated=self._autogenerated,
|
596
593
|
subproject=_SUBPROJECT,
|
597
594
|
)
|
598
|
-
|
599
|
-
|
600
|
-
expected_output_cols_list=(
|
601
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
602
|
-
),
|
595
|
+
expected_output_cols = (
|
596
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
603
597
|
)
|
598
|
+
if isinstance(dataset, DataFrame):
|
599
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
600
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
601
|
+
)
|
602
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
603
|
+
drop_input_cols=self._drop_input_cols,
|
604
|
+
expected_output_cols_list=expected_output_cols,
|
605
|
+
example_output_pd_df=example_output_pd_df,
|
606
|
+
)
|
607
|
+
else:
|
608
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
609
|
+
drop_input_cols=self._drop_input_cols,
|
610
|
+
expected_output_cols_list=expected_output_cols,
|
611
|
+
)
|
604
612
|
self._sklearn_object = fitted_estimator
|
605
613
|
self._is_fitted = True
|
606
614
|
return output_result
|
@@ -623,6 +631,7 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
623
631
|
"""
|
624
632
|
self._infer_input_output_cols(dataset)
|
625
633
|
super()._check_dataset_type(dataset)
|
634
|
+
|
626
635
|
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
627
636
|
estimator=self._sklearn_object,
|
628
637
|
dataset=dataset,
|
@@ -679,12 +688,41 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
679
688
|
|
680
689
|
return rv
|
681
690
|
|
682
|
-
def
|
683
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
684
|
-
) -> List[str]:
|
691
|
+
def _align_expected_output(
|
692
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
693
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
694
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
695
|
+
and output dataframe with 1 line.
|
696
|
+
If the method is fit_predict, run 2 lines of data.
|
697
|
+
"""
|
685
698
|
# in case the inferred output column names dimension is different
|
686
699
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
687
|
-
|
700
|
+
|
701
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
702
|
+
# so change the minimum of number of rows to 2
|
703
|
+
num_examples = 2
|
704
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
705
|
+
project=_PROJECT,
|
706
|
+
subproject=_SUBPROJECT,
|
707
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
708
|
+
inspect.currentframe(), ExtraTreeClassifier.__class__.__name__
|
709
|
+
),
|
710
|
+
api_calls=[Session.call],
|
711
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
712
|
+
)
|
713
|
+
if output_cols_prefix == "fit_predict_":
|
714
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
715
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
716
|
+
num_examples = self._sklearn_object.n_clusters
|
717
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
718
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
719
|
+
num_examples = self._sklearn_object.min_samples
|
720
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
721
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
722
|
+
num_examples = self._sklearn_object.n_neighbors
|
723
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
724
|
+
else:
|
725
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
688
726
|
|
689
727
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
690
728
|
# seen during the fit.
|
@@ -696,12 +734,14 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
696
734
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
697
735
|
if self.sample_weight_col:
|
698
736
|
output_df_columns_set -= set(self.sample_weight_col)
|
737
|
+
|
699
738
|
# if the dimension of inferred output column names is correct; use it
|
700
739
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
701
|
-
return expected_output_cols_list
|
740
|
+
return expected_output_cols_list, output_df_pd
|
702
741
|
# otherwise, use the sklearn estimator's output
|
703
742
|
else:
|
704
|
-
|
743
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
744
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
705
745
|
|
706
746
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
707
747
|
@telemetry.send_api_usage_telemetry(
|
@@ -749,7 +789,7 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
749
789
|
drop_input_cols=self._drop_input_cols,
|
750
790
|
expected_output_cols_type="float",
|
751
791
|
)
|
752
|
-
expected_output_cols = self.
|
792
|
+
expected_output_cols, _ = self._align_expected_output(
|
753
793
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
754
794
|
)
|
755
795
|
|
@@ -817,7 +857,7 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
817
857
|
drop_input_cols=self._drop_input_cols,
|
818
858
|
expected_output_cols_type="float",
|
819
859
|
)
|
820
|
-
expected_output_cols = self.
|
860
|
+
expected_output_cols, _ = self._align_expected_output(
|
821
861
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
822
862
|
)
|
823
863
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -880,7 +920,7 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
880
920
|
drop_input_cols=self._drop_input_cols,
|
881
921
|
expected_output_cols_type="float",
|
882
922
|
)
|
883
|
-
expected_output_cols = self.
|
923
|
+
expected_output_cols, _ = self._align_expected_output(
|
884
924
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
885
925
|
)
|
886
926
|
|
@@ -945,7 +985,7 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
945
985
|
drop_input_cols = self._drop_input_cols,
|
946
986
|
expected_output_cols_type="float",
|
947
987
|
)
|
948
|
-
expected_output_cols = self.
|
988
|
+
expected_output_cols, _ = self._align_expected_output(
|
949
989
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
950
990
|
)
|
951
991
|
|