snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -0
- snowflake/cortex/_classify_text.py +2 -2
- snowflake/cortex/_embed_text_1024.py +37 -0
- snowflake/cortex/_embed_text_768.py +37 -0
- snowflake/cortex/_extract_answer.py +2 -2
- snowflake/cortex/_sentiment.py +2 -2
- snowflake/cortex/_summarize.py +2 -2
- snowflake/cortex/_translate.py +2 -2
- snowflake/cortex/_util.py +4 -4
- snowflake/ml/_internal/env_utils.py +5 -5
- snowflake/ml/_internal/exceptions/error_codes.py +2 -0
- snowflake/ml/_internal/telemetry.py +142 -20
- snowflake/ml/_internal/utils/db_utils.py +50 -0
- snowflake/ml/_internal/utils/identifier.py +48 -11
- snowflake/ml/_internal/utils/service_logger.py +63 -0
- snowflake/ml/_internal/utils/snowflake_env.py +23 -13
- snowflake/ml/_internal/utils/sql_identifier.py +26 -2
- snowflake/ml/_internal/utils/table_manager.py +19 -1
- snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
- snowflake/ml/data/data_connector.py +33 -7
- snowflake/ml/data/ingestor_utils.py +20 -10
- snowflake/ml/data/torch_utils.py +68 -0
- snowflake/ml/dataset/dataset.py +1 -3
- snowflake/ml/feature_store/access_manager.py +3 -3
- snowflake/ml/feature_store/feature_store.py +60 -19
- snowflake/ml/feature_store/feature_view.py +84 -30
- snowflake/ml/fileset/embedded_stage_fs.py +1 -1
- snowflake/ml/fileset/fileset.py +1 -1
- snowflake/ml/fileset/sfcfs.py +9 -3
- snowflake/ml/fileset/stage_fs.py +2 -1
- snowflake/ml/lineage/lineage_node.py +7 -2
- snowflake/ml/model/__init__.py +1 -2
- snowflake/ml/model/_client/model/model_version_impl.py +96 -12
- snowflake/ml/model/_client/ops/model_ops.py +124 -6
- snowflake/ml/model/_client/ops/service_ops.py +309 -9
- snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
- snowflake/ml/model/_client/sql/_base.py +5 -0
- snowflake/ml/model/_client/sql/model.py +1 -0
- snowflake/ml/model/_client/sql/model_version.py +9 -5
- snowflake/ml/model/_client/sql/service.py +121 -20
- snowflake/ml/model/_model_composer/model_composer.py +11 -39
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
- snowflake/ml/model/_packager/model_env/model_env.py +4 -38
- snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
- snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
- snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
- snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
- snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
- snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
- snowflake/ml/model/_packager/model_packager.py +14 -8
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
- snowflake/ml/model/_signatures/utils.py +9 -0
- snowflake/ml/model/type_hints.py +12 -145
- snowflake/ml/modeling/_internal/constants.py +1 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
- snowflake/ml/modeling/_internal/model_specifications.py +2 -0
- snowflake/ml/modeling/_internal/model_trainer.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
- snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
- snowflake/ml/modeling/cluster/birch.py +61 -21
- snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
- snowflake/ml/modeling/cluster/dbscan.py +61 -21
- snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
- snowflake/ml/modeling/cluster/k_means.py +61 -21
- snowflake/ml/modeling/cluster/mean_shift.py +61 -21
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
- snowflake/ml/modeling/cluster/optics.py +61 -21
- snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
- snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
- snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
- snowflake/ml/modeling/compose/column_transformer.py +61 -21
- snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
- snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
- snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
- snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
- snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
- snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
- snowflake/ml/modeling/covariance/oas.py +61 -21
- snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
- snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
- snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
- snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
- snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
- snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
- snowflake/ml/modeling/decomposition/pca.py +61 -21
- snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
- snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
- snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
- snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
- snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
- snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
- snowflake/ml/modeling/impute/knn_imputer.py +61 -21
- snowflake/ml/modeling/impute/missing_indicator.py +61 -21
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
- snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
- snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
- snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/lars.py +61 -21
- snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
- snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
- snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/perceptron.py +61 -21
- snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ridge.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
- snowflake/ml/modeling/manifold/isomap.py +61 -21
- snowflake/ml/modeling/manifold/mds.py +61 -21
- snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
- snowflake/ml/modeling/manifold/tsne.py +61 -21
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
- snowflake/ml/modeling/metrics/ranking.py +0 -3
- snowflake/ml/modeling/metrics/regression.py +0 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
- snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
- snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
- snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
- snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
- snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
- snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
- snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +1 -13
- snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
- snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
- snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
- snowflake/ml/modeling/svm/linear_svc.py +61 -21
- snowflake/ml/modeling/svm/linear_svr.py +61 -21
- snowflake/ml/modeling/svm/nu_svc.py +61 -21
- snowflake/ml/modeling/svm/nu_svr.py +61 -21
- snowflake/ml/modeling/svm/svc.py +61 -21
- snowflake/ml/modeling/svm/svr.py +61 -21
- snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
- snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
- snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
- snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
- snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
- snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
- snowflake/ml/monitoring/_client/model_monitor.py +126 -0
- snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
- snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
- snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
- snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
- snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
- snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
- snowflake/ml/monitoring/entities/output_score_type.py +90 -0
- snowflake/ml/registry/_manager/model_manager.py +4 -0
- snowflake/ml/registry/registry.py +166 -8
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
- snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
- snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
- snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
- snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
- snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
- snowflake/ml/_internal/utils/session_token_manager.py +0 -46
- snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
- snowflake/ml/_internal/utils/uri.py +0 -77
- snowflake/ml/data/torch_dataset.py +0 -33
- snowflake/ml/model/_api.py +0 -568
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
- snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
- snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
- snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
- snowflake/ml/model/deploy_platforms.py +0 -6
- snowflake/ml/model/models/llm.py +0 -104
- snowflake/ml/monitoring/monitor.py +0 -203
- snowflake/ml/registry/_initial_schema.py +0 -142
- snowflake/ml/registry/_schema.py +0 -82
- snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
- snowflake/ml/registry/_schema_version_manager.py +0 -163
- snowflake/ml/registry/model_registry.py +0 -2048
- snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -570,12 +567,23 @@ class RidgeClassifier(BaseTransformer):
|
|
570
567
|
autogenerated=self._autogenerated,
|
571
568
|
subproject=_SUBPROJECT,
|
572
569
|
)
|
573
|
-
|
574
|
-
|
575
|
-
expected_output_cols_list=(
|
576
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
577
|
-
),
|
570
|
+
expected_output_cols = (
|
571
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
578
572
|
)
|
573
|
+
if isinstance(dataset, DataFrame):
|
574
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
575
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
576
|
+
)
|
577
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
578
|
+
drop_input_cols=self._drop_input_cols,
|
579
|
+
expected_output_cols_list=expected_output_cols,
|
580
|
+
example_output_pd_df=example_output_pd_df,
|
581
|
+
)
|
582
|
+
else:
|
583
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
584
|
+
drop_input_cols=self._drop_input_cols,
|
585
|
+
expected_output_cols_list=expected_output_cols,
|
586
|
+
)
|
579
587
|
self._sklearn_object = fitted_estimator
|
580
588
|
self._is_fitted = True
|
581
589
|
return output_result
|
@@ -598,6 +606,7 @@ class RidgeClassifier(BaseTransformer):
|
|
598
606
|
"""
|
599
607
|
self._infer_input_output_cols(dataset)
|
600
608
|
super()._check_dataset_type(dataset)
|
609
|
+
|
601
610
|
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
602
611
|
estimator=self._sklearn_object,
|
603
612
|
dataset=dataset,
|
@@ -654,12 +663,41 @@ class RidgeClassifier(BaseTransformer):
|
|
654
663
|
|
655
664
|
return rv
|
656
665
|
|
657
|
-
def
|
658
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
659
|
-
) -> List[str]:
|
666
|
+
def _align_expected_output(
|
667
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
668
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
669
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
670
|
+
and output dataframe with 1 line.
|
671
|
+
If the method is fit_predict, run 2 lines of data.
|
672
|
+
"""
|
660
673
|
# in case the inferred output column names dimension is different
|
661
674
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
662
|
-
|
675
|
+
|
676
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
677
|
+
# so change the minimum of number of rows to 2
|
678
|
+
num_examples = 2
|
679
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
680
|
+
project=_PROJECT,
|
681
|
+
subproject=_SUBPROJECT,
|
682
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
683
|
+
inspect.currentframe(), RidgeClassifier.__class__.__name__
|
684
|
+
),
|
685
|
+
api_calls=[Session.call],
|
686
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
687
|
+
)
|
688
|
+
if output_cols_prefix == "fit_predict_":
|
689
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
690
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
691
|
+
num_examples = self._sklearn_object.n_clusters
|
692
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
693
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
694
|
+
num_examples = self._sklearn_object.min_samples
|
695
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
696
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
697
|
+
num_examples = self._sklearn_object.n_neighbors
|
698
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
699
|
+
else:
|
700
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
663
701
|
|
664
702
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
665
703
|
# seen during the fit.
|
@@ -671,12 +709,14 @@ class RidgeClassifier(BaseTransformer):
|
|
671
709
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
672
710
|
if self.sample_weight_col:
|
673
711
|
output_df_columns_set -= set(self.sample_weight_col)
|
712
|
+
|
674
713
|
# if the dimension of inferred output column names is correct; use it
|
675
714
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
676
|
-
return expected_output_cols_list
|
715
|
+
return expected_output_cols_list, output_df_pd
|
677
716
|
# otherwise, use the sklearn estimator's output
|
678
717
|
else:
|
679
|
-
|
718
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
719
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
680
720
|
|
681
721
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
682
722
|
@telemetry.send_api_usage_telemetry(
|
@@ -722,7 +762,7 @@ class RidgeClassifier(BaseTransformer):
|
|
722
762
|
drop_input_cols=self._drop_input_cols,
|
723
763
|
expected_output_cols_type="float",
|
724
764
|
)
|
725
|
-
expected_output_cols = self.
|
765
|
+
expected_output_cols, _ = self._align_expected_output(
|
726
766
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
727
767
|
)
|
728
768
|
|
@@ -788,7 +828,7 @@ class RidgeClassifier(BaseTransformer):
|
|
788
828
|
drop_input_cols=self._drop_input_cols,
|
789
829
|
expected_output_cols_type="float",
|
790
830
|
)
|
791
|
-
expected_output_cols = self.
|
831
|
+
expected_output_cols, _ = self._align_expected_output(
|
792
832
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
793
833
|
)
|
794
834
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -853,7 +893,7 @@ class RidgeClassifier(BaseTransformer):
|
|
853
893
|
drop_input_cols=self._drop_input_cols,
|
854
894
|
expected_output_cols_type="float",
|
855
895
|
)
|
856
|
-
expected_output_cols = self.
|
896
|
+
expected_output_cols, _ = self._align_expected_output(
|
857
897
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
858
898
|
)
|
859
899
|
|
@@ -918,7 +958,7 @@ class RidgeClassifier(BaseTransformer):
|
|
918
958
|
drop_input_cols = self._drop_input_cols,
|
919
959
|
expected_output_cols_type="float",
|
920
960
|
)
|
921
|
-
expected_output_cols = self.
|
961
|
+
expected_output_cols, _ = self._align_expected_output(
|
922
962
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
923
963
|
)
|
924
964
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -521,12 +518,23 @@ class RidgeClassifierCV(BaseTransformer):
|
|
521
518
|
autogenerated=self._autogenerated,
|
522
519
|
subproject=_SUBPROJECT,
|
523
520
|
)
|
524
|
-
|
525
|
-
|
526
|
-
expected_output_cols_list=(
|
527
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
528
|
-
),
|
521
|
+
expected_output_cols = (
|
522
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
529
523
|
)
|
524
|
+
if isinstance(dataset, DataFrame):
|
525
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
526
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
527
|
+
)
|
528
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
529
|
+
drop_input_cols=self._drop_input_cols,
|
530
|
+
expected_output_cols_list=expected_output_cols,
|
531
|
+
example_output_pd_df=example_output_pd_df,
|
532
|
+
)
|
533
|
+
else:
|
534
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
535
|
+
drop_input_cols=self._drop_input_cols,
|
536
|
+
expected_output_cols_list=expected_output_cols,
|
537
|
+
)
|
530
538
|
self._sklearn_object = fitted_estimator
|
531
539
|
self._is_fitted = True
|
532
540
|
return output_result
|
@@ -549,6 +557,7 @@ class RidgeClassifierCV(BaseTransformer):
|
|
549
557
|
"""
|
550
558
|
self._infer_input_output_cols(dataset)
|
551
559
|
super()._check_dataset_type(dataset)
|
560
|
+
|
552
561
|
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
553
562
|
estimator=self._sklearn_object,
|
554
563
|
dataset=dataset,
|
@@ -605,12 +614,41 @@ class RidgeClassifierCV(BaseTransformer):
|
|
605
614
|
|
606
615
|
return rv
|
607
616
|
|
608
|
-
def
|
609
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
610
|
-
) -> List[str]:
|
617
|
+
def _align_expected_output(
|
618
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
619
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
620
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
621
|
+
and output dataframe with 1 line.
|
622
|
+
If the method is fit_predict, run 2 lines of data.
|
623
|
+
"""
|
611
624
|
# in case the inferred output column names dimension is different
|
612
625
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
613
|
-
|
626
|
+
|
627
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
628
|
+
# so change the minimum of number of rows to 2
|
629
|
+
num_examples = 2
|
630
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
631
|
+
project=_PROJECT,
|
632
|
+
subproject=_SUBPROJECT,
|
633
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
634
|
+
inspect.currentframe(), RidgeClassifierCV.__class__.__name__
|
635
|
+
),
|
636
|
+
api_calls=[Session.call],
|
637
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
638
|
+
)
|
639
|
+
if output_cols_prefix == "fit_predict_":
|
640
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
641
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
642
|
+
num_examples = self._sklearn_object.n_clusters
|
643
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
644
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
645
|
+
num_examples = self._sklearn_object.min_samples
|
646
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
647
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
648
|
+
num_examples = self._sklearn_object.n_neighbors
|
649
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
650
|
+
else:
|
651
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
614
652
|
|
615
653
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
616
654
|
# seen during the fit.
|
@@ -622,12 +660,14 @@ class RidgeClassifierCV(BaseTransformer):
|
|
622
660
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
623
661
|
if self.sample_weight_col:
|
624
662
|
output_df_columns_set -= set(self.sample_weight_col)
|
663
|
+
|
625
664
|
# if the dimension of inferred output column names is correct; use it
|
626
665
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
627
|
-
return expected_output_cols_list
|
666
|
+
return expected_output_cols_list, output_df_pd
|
628
667
|
# otherwise, use the sklearn estimator's output
|
629
668
|
else:
|
630
|
-
|
669
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
670
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
631
671
|
|
632
672
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
633
673
|
@telemetry.send_api_usage_telemetry(
|
@@ -673,7 +713,7 @@ class RidgeClassifierCV(BaseTransformer):
|
|
673
713
|
drop_input_cols=self._drop_input_cols,
|
674
714
|
expected_output_cols_type="float",
|
675
715
|
)
|
676
|
-
expected_output_cols = self.
|
716
|
+
expected_output_cols, _ = self._align_expected_output(
|
677
717
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
678
718
|
)
|
679
719
|
|
@@ -739,7 +779,7 @@ class RidgeClassifierCV(BaseTransformer):
|
|
739
779
|
drop_input_cols=self._drop_input_cols,
|
740
780
|
expected_output_cols_type="float",
|
741
781
|
)
|
742
|
-
expected_output_cols = self.
|
782
|
+
expected_output_cols, _ = self._align_expected_output(
|
743
783
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
744
784
|
)
|
745
785
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -804,7 +844,7 @@ class RidgeClassifierCV(BaseTransformer):
|
|
804
844
|
drop_input_cols=self._drop_input_cols,
|
805
845
|
expected_output_cols_type="float",
|
806
846
|
)
|
807
|
-
expected_output_cols = self.
|
847
|
+
expected_output_cols, _ = self._align_expected_output(
|
808
848
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
809
849
|
)
|
810
850
|
|
@@ -869,7 +909,7 @@ class RidgeClassifierCV(BaseTransformer):
|
|
869
909
|
drop_input_cols = self._drop_input_cols,
|
870
910
|
expected_output_cols_type="float",
|
871
911
|
)
|
872
|
-
expected_output_cols = self.
|
912
|
+
expected_output_cols, _ = self._align_expected_output(
|
873
913
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
874
914
|
)
|
875
915
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -542,12 +539,23 @@ class RidgeCV(BaseTransformer):
|
|
542
539
|
autogenerated=self._autogenerated,
|
543
540
|
subproject=_SUBPROJECT,
|
544
541
|
)
|
545
|
-
|
546
|
-
|
547
|
-
expected_output_cols_list=(
|
548
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
549
|
-
),
|
542
|
+
expected_output_cols = (
|
543
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
550
544
|
)
|
545
|
+
if isinstance(dataset, DataFrame):
|
546
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
547
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
548
|
+
)
|
549
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
550
|
+
drop_input_cols=self._drop_input_cols,
|
551
|
+
expected_output_cols_list=expected_output_cols,
|
552
|
+
example_output_pd_df=example_output_pd_df,
|
553
|
+
)
|
554
|
+
else:
|
555
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
556
|
+
drop_input_cols=self._drop_input_cols,
|
557
|
+
expected_output_cols_list=expected_output_cols,
|
558
|
+
)
|
551
559
|
self._sklearn_object = fitted_estimator
|
552
560
|
self._is_fitted = True
|
553
561
|
return output_result
|
@@ -570,6 +578,7 @@ class RidgeCV(BaseTransformer):
|
|
570
578
|
"""
|
571
579
|
self._infer_input_output_cols(dataset)
|
572
580
|
super()._check_dataset_type(dataset)
|
581
|
+
|
573
582
|
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
574
583
|
estimator=self._sklearn_object,
|
575
584
|
dataset=dataset,
|
@@ -626,12 +635,41 @@ class RidgeCV(BaseTransformer):
|
|
626
635
|
|
627
636
|
return rv
|
628
637
|
|
629
|
-
def
|
630
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
631
|
-
) -> List[str]:
|
638
|
+
def _align_expected_output(
|
639
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
640
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
641
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
642
|
+
and output dataframe with 1 line.
|
643
|
+
If the method is fit_predict, run 2 lines of data.
|
644
|
+
"""
|
632
645
|
# in case the inferred output column names dimension is different
|
633
646
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
634
|
-
|
647
|
+
|
648
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
649
|
+
# so change the minimum of number of rows to 2
|
650
|
+
num_examples = 2
|
651
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
652
|
+
project=_PROJECT,
|
653
|
+
subproject=_SUBPROJECT,
|
654
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
655
|
+
inspect.currentframe(), RidgeCV.__class__.__name__
|
656
|
+
),
|
657
|
+
api_calls=[Session.call],
|
658
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
659
|
+
)
|
660
|
+
if output_cols_prefix == "fit_predict_":
|
661
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
662
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
663
|
+
num_examples = self._sklearn_object.n_clusters
|
664
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
665
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
666
|
+
num_examples = self._sklearn_object.min_samples
|
667
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
668
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
669
|
+
num_examples = self._sklearn_object.n_neighbors
|
670
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
671
|
+
else:
|
672
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
635
673
|
|
636
674
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
637
675
|
# seen during the fit.
|
@@ -643,12 +681,14 @@ class RidgeCV(BaseTransformer):
|
|
643
681
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
644
682
|
if self.sample_weight_col:
|
645
683
|
output_df_columns_set -= set(self.sample_weight_col)
|
684
|
+
|
646
685
|
# if the dimension of inferred output column names is correct; use it
|
647
686
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
648
|
-
return expected_output_cols_list
|
687
|
+
return expected_output_cols_list, output_df_pd
|
649
688
|
# otherwise, use the sklearn estimator's output
|
650
689
|
else:
|
651
|
-
|
690
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
691
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
652
692
|
|
653
693
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
654
694
|
@telemetry.send_api_usage_telemetry(
|
@@ -694,7 +734,7 @@ class RidgeCV(BaseTransformer):
|
|
694
734
|
drop_input_cols=self._drop_input_cols,
|
695
735
|
expected_output_cols_type="float",
|
696
736
|
)
|
697
|
-
expected_output_cols = self.
|
737
|
+
expected_output_cols, _ = self._align_expected_output(
|
698
738
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
699
739
|
)
|
700
740
|
|
@@ -760,7 +800,7 @@ class RidgeCV(BaseTransformer):
|
|
760
800
|
drop_input_cols=self._drop_input_cols,
|
761
801
|
expected_output_cols_type="float",
|
762
802
|
)
|
763
|
-
expected_output_cols = self.
|
803
|
+
expected_output_cols, _ = self._align_expected_output(
|
764
804
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
765
805
|
)
|
766
806
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -823,7 +863,7 @@ class RidgeCV(BaseTransformer):
|
|
823
863
|
drop_input_cols=self._drop_input_cols,
|
824
864
|
expected_output_cols_type="float",
|
825
865
|
)
|
826
|
-
expected_output_cols = self.
|
866
|
+
expected_output_cols, _ = self._align_expected_output(
|
827
867
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
828
868
|
)
|
829
869
|
|
@@ -888,7 +928,7 @@ class RidgeCV(BaseTransformer):
|
|
888
928
|
drop_input_cols = self._drop_input_cols,
|
889
929
|
expected_output_cols_type="float",
|
890
930
|
)
|
891
|
-
expected_output_cols = self.
|
931
|
+
expected_output_cols, _ = self._align_expected_output(
|
892
932
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
893
933
|
)
|
894
934
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -661,12 +658,23 @@ class SGDClassifier(BaseTransformer):
|
|
661
658
|
autogenerated=self._autogenerated,
|
662
659
|
subproject=_SUBPROJECT,
|
663
660
|
)
|
664
|
-
|
665
|
-
|
666
|
-
expected_output_cols_list=(
|
667
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
668
|
-
),
|
661
|
+
expected_output_cols = (
|
662
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
669
663
|
)
|
664
|
+
if isinstance(dataset, DataFrame):
|
665
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
666
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
667
|
+
)
|
668
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
669
|
+
drop_input_cols=self._drop_input_cols,
|
670
|
+
expected_output_cols_list=expected_output_cols,
|
671
|
+
example_output_pd_df=example_output_pd_df,
|
672
|
+
)
|
673
|
+
else:
|
674
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
675
|
+
drop_input_cols=self._drop_input_cols,
|
676
|
+
expected_output_cols_list=expected_output_cols,
|
677
|
+
)
|
670
678
|
self._sklearn_object = fitted_estimator
|
671
679
|
self._is_fitted = True
|
672
680
|
return output_result
|
@@ -689,6 +697,7 @@ class SGDClassifier(BaseTransformer):
|
|
689
697
|
"""
|
690
698
|
self._infer_input_output_cols(dataset)
|
691
699
|
super()._check_dataset_type(dataset)
|
700
|
+
|
692
701
|
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
693
702
|
estimator=self._sklearn_object,
|
694
703
|
dataset=dataset,
|
@@ -745,12 +754,41 @@ class SGDClassifier(BaseTransformer):
|
|
745
754
|
|
746
755
|
return rv
|
747
756
|
|
748
|
-
def
|
749
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
750
|
-
) -> List[str]:
|
757
|
+
def _align_expected_output(
|
758
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
759
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
760
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
761
|
+
and output dataframe with 1 line.
|
762
|
+
If the method is fit_predict, run 2 lines of data.
|
763
|
+
"""
|
751
764
|
# in case the inferred output column names dimension is different
|
752
765
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
753
|
-
|
766
|
+
|
767
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
768
|
+
# so change the minimum of number of rows to 2
|
769
|
+
num_examples = 2
|
770
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
771
|
+
project=_PROJECT,
|
772
|
+
subproject=_SUBPROJECT,
|
773
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
774
|
+
inspect.currentframe(), SGDClassifier.__class__.__name__
|
775
|
+
),
|
776
|
+
api_calls=[Session.call],
|
777
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
778
|
+
)
|
779
|
+
if output_cols_prefix == "fit_predict_":
|
780
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
781
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
782
|
+
num_examples = self._sklearn_object.n_clusters
|
783
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
784
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
785
|
+
num_examples = self._sklearn_object.min_samples
|
786
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
787
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
788
|
+
num_examples = self._sklearn_object.n_neighbors
|
789
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
790
|
+
else:
|
791
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
754
792
|
|
755
793
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
756
794
|
# seen during the fit.
|
@@ -762,12 +800,14 @@ class SGDClassifier(BaseTransformer):
|
|
762
800
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
763
801
|
if self.sample_weight_col:
|
764
802
|
output_df_columns_set -= set(self.sample_weight_col)
|
803
|
+
|
765
804
|
# if the dimension of inferred output column names is correct; use it
|
766
805
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
767
|
-
return expected_output_cols_list
|
806
|
+
return expected_output_cols_list, output_df_pd
|
768
807
|
# otherwise, use the sklearn estimator's output
|
769
808
|
else:
|
770
|
-
|
809
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
810
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
771
811
|
|
772
812
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
773
813
|
@telemetry.send_api_usage_telemetry(
|
@@ -815,7 +855,7 @@ class SGDClassifier(BaseTransformer):
|
|
815
855
|
drop_input_cols=self._drop_input_cols,
|
816
856
|
expected_output_cols_type="float",
|
817
857
|
)
|
818
|
-
expected_output_cols = self.
|
858
|
+
expected_output_cols, _ = self._align_expected_output(
|
819
859
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
820
860
|
)
|
821
861
|
|
@@ -883,7 +923,7 @@ class SGDClassifier(BaseTransformer):
|
|
883
923
|
drop_input_cols=self._drop_input_cols,
|
884
924
|
expected_output_cols_type="float",
|
885
925
|
)
|
886
|
-
expected_output_cols = self.
|
926
|
+
expected_output_cols, _ = self._align_expected_output(
|
887
927
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
888
928
|
)
|
889
929
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -948,7 +988,7 @@ class SGDClassifier(BaseTransformer):
|
|
948
988
|
drop_input_cols=self._drop_input_cols,
|
949
989
|
expected_output_cols_type="float",
|
950
990
|
)
|
951
|
-
expected_output_cols = self.
|
991
|
+
expected_output_cols, _ = self._align_expected_output(
|
952
992
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
953
993
|
)
|
954
994
|
|
@@ -1013,7 +1053,7 @@ class SGDClassifier(BaseTransformer):
|
|
1013
1053
|
drop_input_cols = self._drop_input_cols,
|
1014
1054
|
expected_output_cols_type="float",
|
1015
1055
|
)
|
1016
|
-
expected_output_cols = self.
|
1056
|
+
expected_output_cols, _ = self._align_expected_output(
|
1017
1057
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
1018
1058
|
)
|
1019
1059
|
|