snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -0
- snowflake/cortex/_classify_text.py +2 -2
- snowflake/cortex/_embed_text_1024.py +37 -0
- snowflake/cortex/_embed_text_768.py +37 -0
- snowflake/cortex/_extract_answer.py +2 -2
- snowflake/cortex/_sentiment.py +2 -2
- snowflake/cortex/_summarize.py +2 -2
- snowflake/cortex/_translate.py +2 -2
- snowflake/cortex/_util.py +4 -4
- snowflake/ml/_internal/env_utils.py +5 -5
- snowflake/ml/_internal/exceptions/error_codes.py +2 -0
- snowflake/ml/_internal/telemetry.py +142 -20
- snowflake/ml/_internal/utils/db_utils.py +50 -0
- snowflake/ml/_internal/utils/identifier.py +48 -11
- snowflake/ml/_internal/utils/service_logger.py +63 -0
- snowflake/ml/_internal/utils/snowflake_env.py +23 -13
- snowflake/ml/_internal/utils/sql_identifier.py +26 -2
- snowflake/ml/_internal/utils/table_manager.py +19 -1
- snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
- snowflake/ml/data/data_connector.py +33 -7
- snowflake/ml/data/ingestor_utils.py +20 -10
- snowflake/ml/data/torch_utils.py +68 -0
- snowflake/ml/dataset/dataset.py +1 -3
- snowflake/ml/feature_store/access_manager.py +3 -3
- snowflake/ml/feature_store/feature_store.py +60 -19
- snowflake/ml/feature_store/feature_view.py +84 -30
- snowflake/ml/fileset/embedded_stage_fs.py +1 -1
- snowflake/ml/fileset/fileset.py +1 -1
- snowflake/ml/fileset/sfcfs.py +9 -3
- snowflake/ml/fileset/stage_fs.py +2 -1
- snowflake/ml/lineage/lineage_node.py +7 -2
- snowflake/ml/model/__init__.py +1 -2
- snowflake/ml/model/_client/model/model_version_impl.py +96 -12
- snowflake/ml/model/_client/ops/model_ops.py +124 -6
- snowflake/ml/model/_client/ops/service_ops.py +309 -9
- snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
- snowflake/ml/model/_client/sql/_base.py +5 -0
- snowflake/ml/model/_client/sql/model.py +1 -0
- snowflake/ml/model/_client/sql/model_version.py +9 -5
- snowflake/ml/model/_client/sql/service.py +121 -20
- snowflake/ml/model/_model_composer/model_composer.py +11 -39
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
- snowflake/ml/model/_packager/model_env/model_env.py +4 -38
- snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
- snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
- snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
- snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
- snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
- snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
- snowflake/ml/model/_packager/model_packager.py +14 -8
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
- snowflake/ml/model/_signatures/utils.py +9 -0
- snowflake/ml/model/type_hints.py +12 -145
- snowflake/ml/modeling/_internal/constants.py +1 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
- snowflake/ml/modeling/_internal/model_specifications.py +2 -0
- snowflake/ml/modeling/_internal/model_trainer.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
- snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
- snowflake/ml/modeling/cluster/birch.py +61 -21
- snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
- snowflake/ml/modeling/cluster/dbscan.py +61 -21
- snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
- snowflake/ml/modeling/cluster/k_means.py +61 -21
- snowflake/ml/modeling/cluster/mean_shift.py +61 -21
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
- snowflake/ml/modeling/cluster/optics.py +61 -21
- snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
- snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
- snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
- snowflake/ml/modeling/compose/column_transformer.py +61 -21
- snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
- snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
- snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
- snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
- snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
- snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
- snowflake/ml/modeling/covariance/oas.py +61 -21
- snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
- snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
- snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
- snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
- snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
- snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
- snowflake/ml/modeling/decomposition/pca.py +61 -21
- snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
- snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
- snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
- snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
- snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
- snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
- snowflake/ml/modeling/impute/knn_imputer.py +61 -21
- snowflake/ml/modeling/impute/missing_indicator.py +61 -21
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
- snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
- snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
- snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/lars.py +61 -21
- snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
- snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
- snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/perceptron.py +61 -21
- snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ridge.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
- snowflake/ml/modeling/manifold/isomap.py +61 -21
- snowflake/ml/modeling/manifold/mds.py +61 -21
- snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
- snowflake/ml/modeling/manifold/tsne.py +61 -21
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
- snowflake/ml/modeling/metrics/ranking.py +0 -3
- snowflake/ml/modeling/metrics/regression.py +0 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
- snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
- snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
- snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
- snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
- snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
- snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
- snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +1 -13
- snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
- snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
- snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
- snowflake/ml/modeling/svm/linear_svc.py +61 -21
- snowflake/ml/modeling/svm/linear_svr.py +61 -21
- snowflake/ml/modeling/svm/nu_svc.py +61 -21
- snowflake/ml/modeling/svm/nu_svr.py +61 -21
- snowflake/ml/modeling/svm/svc.py +61 -21
- snowflake/ml/modeling/svm/svr.py +61 -21
- snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
- snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
- snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
- snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
- snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
- snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
- snowflake/ml/monitoring/_client/model_monitor.py +126 -0
- snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
- snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
- snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
- snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
- snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
- snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
- snowflake/ml/monitoring/entities/output_score_type.py +90 -0
- snowflake/ml/registry/_manager/model_manager.py +4 -0
- snowflake/ml/registry/registry.py +166 -8
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
- snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
- snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
- snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
- snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
- snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
- snowflake/ml/_internal/utils/session_token_manager.py +0 -46
- snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
- snowflake/ml/_internal/utils/uri.py +0 -77
- snowflake/ml/data/torch_dataset.py +0 -33
- snowflake/ml/model/_api.py +0 -568
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
- snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
- snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
- snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
- snowflake/ml/model/deploy_platforms.py +0 -6
- snowflake/ml/model/models/llm.py +0 -104
- snowflake/ml/monitoring/monitor.py +0 -203
- snowflake/ml/registry/_initial_schema.py +0 -142
- snowflake/ml/registry/_schema.py +0 -82
- snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
- snowflake/ml/registry/_schema_version_manager.py +0 -163
- snowflake/ml/registry/model_registry.py +0 -2048
- snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -534,12 +531,23 @@ class SparsePCA(BaseTransformer):
|
|
534
531
|
autogenerated=self._autogenerated,
|
535
532
|
subproject=_SUBPROJECT,
|
536
533
|
)
|
537
|
-
|
538
|
-
|
539
|
-
expected_output_cols_list=(
|
540
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
541
|
-
),
|
534
|
+
expected_output_cols = (
|
535
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
542
536
|
)
|
537
|
+
if isinstance(dataset, DataFrame):
|
538
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
539
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
540
|
+
)
|
541
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
542
|
+
drop_input_cols=self._drop_input_cols,
|
543
|
+
expected_output_cols_list=expected_output_cols,
|
544
|
+
example_output_pd_df=example_output_pd_df,
|
545
|
+
)
|
546
|
+
else:
|
547
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
548
|
+
drop_input_cols=self._drop_input_cols,
|
549
|
+
expected_output_cols_list=expected_output_cols,
|
550
|
+
)
|
543
551
|
self._sklearn_object = fitted_estimator
|
544
552
|
self._is_fitted = True
|
545
553
|
return output_result
|
@@ -564,6 +572,7 @@ class SparsePCA(BaseTransformer):
|
|
564
572
|
"""
|
565
573
|
self._infer_input_output_cols(dataset)
|
566
574
|
super()._check_dataset_type(dataset)
|
575
|
+
|
567
576
|
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
568
577
|
estimator=self._sklearn_object,
|
569
578
|
dataset=dataset,
|
@@ -620,12 +629,41 @@ class SparsePCA(BaseTransformer):
|
|
620
629
|
|
621
630
|
return rv
|
622
631
|
|
623
|
-
def
|
624
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
625
|
-
) -> List[str]:
|
632
|
+
def _align_expected_output(
|
633
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
634
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
635
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
636
|
+
and output dataframe with 1 line.
|
637
|
+
If the method is fit_predict, run 2 lines of data.
|
638
|
+
"""
|
626
639
|
# in case the inferred output column names dimension is different
|
627
640
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
628
|
-
|
641
|
+
|
642
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
643
|
+
# so change the minimum of number of rows to 2
|
644
|
+
num_examples = 2
|
645
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
646
|
+
project=_PROJECT,
|
647
|
+
subproject=_SUBPROJECT,
|
648
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
649
|
+
inspect.currentframe(), SparsePCA.__class__.__name__
|
650
|
+
),
|
651
|
+
api_calls=[Session.call],
|
652
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
653
|
+
)
|
654
|
+
if output_cols_prefix == "fit_predict_":
|
655
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
656
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
657
|
+
num_examples = self._sklearn_object.n_clusters
|
658
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
659
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
660
|
+
num_examples = self._sklearn_object.min_samples
|
661
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
662
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
663
|
+
num_examples = self._sklearn_object.n_neighbors
|
664
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
665
|
+
else:
|
666
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
629
667
|
|
630
668
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
631
669
|
# seen during the fit.
|
@@ -637,12 +675,14 @@ class SparsePCA(BaseTransformer):
|
|
637
675
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
638
676
|
if self.sample_weight_col:
|
639
677
|
output_df_columns_set -= set(self.sample_weight_col)
|
678
|
+
|
640
679
|
# if the dimension of inferred output column names is correct; use it
|
641
680
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
642
|
-
return expected_output_cols_list
|
681
|
+
return expected_output_cols_list, output_df_pd
|
643
682
|
# otherwise, use the sklearn estimator's output
|
644
683
|
else:
|
645
|
-
|
684
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
685
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
646
686
|
|
647
687
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
648
688
|
@telemetry.send_api_usage_telemetry(
|
@@ -688,7 +728,7 @@ class SparsePCA(BaseTransformer):
|
|
688
728
|
drop_input_cols=self._drop_input_cols,
|
689
729
|
expected_output_cols_type="float",
|
690
730
|
)
|
691
|
-
expected_output_cols = self.
|
731
|
+
expected_output_cols, _ = self._align_expected_output(
|
692
732
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
693
733
|
)
|
694
734
|
|
@@ -754,7 +794,7 @@ class SparsePCA(BaseTransformer):
|
|
754
794
|
drop_input_cols=self._drop_input_cols,
|
755
795
|
expected_output_cols_type="float",
|
756
796
|
)
|
757
|
-
expected_output_cols = self.
|
797
|
+
expected_output_cols, _ = self._align_expected_output(
|
758
798
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
759
799
|
)
|
760
800
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -817,7 +857,7 @@ class SparsePCA(BaseTransformer):
|
|
817
857
|
drop_input_cols=self._drop_input_cols,
|
818
858
|
expected_output_cols_type="float",
|
819
859
|
)
|
820
|
-
expected_output_cols = self.
|
860
|
+
expected_output_cols, _ = self._align_expected_output(
|
821
861
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
822
862
|
)
|
823
863
|
|
@@ -882,7 +922,7 @@ class SparsePCA(BaseTransformer):
|
|
882
922
|
drop_input_cols = self._drop_input_cols,
|
883
923
|
expected_output_cols_type="float",
|
884
924
|
)
|
885
|
-
expected_output_cols = self.
|
925
|
+
expected_output_cols, _ = self._align_expected_output(
|
886
926
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
887
927
|
)
|
888
928
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -515,12 +512,23 @@ class TruncatedSVD(BaseTransformer):
|
|
515
512
|
autogenerated=self._autogenerated,
|
516
513
|
subproject=_SUBPROJECT,
|
517
514
|
)
|
518
|
-
|
519
|
-
|
520
|
-
expected_output_cols_list=(
|
521
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
522
|
-
),
|
515
|
+
expected_output_cols = (
|
516
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
523
517
|
)
|
518
|
+
if isinstance(dataset, DataFrame):
|
519
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
520
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
521
|
+
)
|
522
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
523
|
+
drop_input_cols=self._drop_input_cols,
|
524
|
+
expected_output_cols_list=expected_output_cols,
|
525
|
+
example_output_pd_df=example_output_pd_df,
|
526
|
+
)
|
527
|
+
else:
|
528
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
529
|
+
drop_input_cols=self._drop_input_cols,
|
530
|
+
expected_output_cols_list=expected_output_cols,
|
531
|
+
)
|
524
532
|
self._sklearn_object = fitted_estimator
|
525
533
|
self._is_fitted = True
|
526
534
|
return output_result
|
@@ -545,6 +553,7 @@ class TruncatedSVD(BaseTransformer):
|
|
545
553
|
"""
|
546
554
|
self._infer_input_output_cols(dataset)
|
547
555
|
super()._check_dataset_type(dataset)
|
556
|
+
|
548
557
|
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
549
558
|
estimator=self._sklearn_object,
|
550
559
|
dataset=dataset,
|
@@ -601,12 +610,41 @@ class TruncatedSVD(BaseTransformer):
|
|
601
610
|
|
602
611
|
return rv
|
603
612
|
|
604
|
-
def
|
605
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
606
|
-
) -> List[str]:
|
613
|
+
def _align_expected_output(
|
614
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
615
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
616
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
617
|
+
and output dataframe with 1 line.
|
618
|
+
If the method is fit_predict, run 2 lines of data.
|
619
|
+
"""
|
607
620
|
# in case the inferred output column names dimension is different
|
608
621
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
609
|
-
|
622
|
+
|
623
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
624
|
+
# so change the minimum of number of rows to 2
|
625
|
+
num_examples = 2
|
626
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
627
|
+
project=_PROJECT,
|
628
|
+
subproject=_SUBPROJECT,
|
629
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
630
|
+
inspect.currentframe(), TruncatedSVD.__class__.__name__
|
631
|
+
),
|
632
|
+
api_calls=[Session.call],
|
633
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
634
|
+
)
|
635
|
+
if output_cols_prefix == "fit_predict_":
|
636
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
637
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
638
|
+
num_examples = self._sklearn_object.n_clusters
|
639
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
640
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
641
|
+
num_examples = self._sklearn_object.min_samples
|
642
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
643
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
644
|
+
num_examples = self._sklearn_object.n_neighbors
|
645
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
646
|
+
else:
|
647
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
610
648
|
|
611
649
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
612
650
|
# seen during the fit.
|
@@ -618,12 +656,14 @@ class TruncatedSVD(BaseTransformer):
|
|
618
656
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
619
657
|
if self.sample_weight_col:
|
620
658
|
output_df_columns_set -= set(self.sample_weight_col)
|
659
|
+
|
621
660
|
# if the dimension of inferred output column names is correct; use it
|
622
661
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
623
|
-
return expected_output_cols_list
|
662
|
+
return expected_output_cols_list, output_df_pd
|
624
663
|
# otherwise, use the sklearn estimator's output
|
625
664
|
else:
|
626
|
-
|
665
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
666
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
627
667
|
|
628
668
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
629
669
|
@telemetry.send_api_usage_telemetry(
|
@@ -669,7 +709,7 @@ class TruncatedSVD(BaseTransformer):
|
|
669
709
|
drop_input_cols=self._drop_input_cols,
|
670
710
|
expected_output_cols_type="float",
|
671
711
|
)
|
672
|
-
expected_output_cols = self.
|
712
|
+
expected_output_cols, _ = self._align_expected_output(
|
673
713
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
674
714
|
)
|
675
715
|
|
@@ -735,7 +775,7 @@ class TruncatedSVD(BaseTransformer):
|
|
735
775
|
drop_input_cols=self._drop_input_cols,
|
736
776
|
expected_output_cols_type="float",
|
737
777
|
)
|
738
|
-
expected_output_cols = self.
|
778
|
+
expected_output_cols, _ = self._align_expected_output(
|
739
779
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
740
780
|
)
|
741
781
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -798,7 +838,7 @@ class TruncatedSVD(BaseTransformer):
|
|
798
838
|
drop_input_cols=self._drop_input_cols,
|
799
839
|
expected_output_cols_type="float",
|
800
840
|
)
|
801
|
-
expected_output_cols = self.
|
841
|
+
expected_output_cols, _ = self._align_expected_output(
|
802
842
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
803
843
|
)
|
804
844
|
|
@@ -863,7 +903,7 @@ class TruncatedSVD(BaseTransformer):
|
|
863
903
|
drop_input_cols = self._drop_input_cols,
|
864
904
|
expected_output_cols_type="float",
|
865
905
|
)
|
866
|
-
expected_output_cols = self.
|
906
|
+
expected_output_cols, _ = self._align_expected_output(
|
867
907
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
868
908
|
)
|
869
909
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -532,12 +529,23 @@ class LinearDiscriminantAnalysis(BaseTransformer):
|
|
532
529
|
autogenerated=self._autogenerated,
|
533
530
|
subproject=_SUBPROJECT,
|
534
531
|
)
|
535
|
-
|
536
|
-
|
537
|
-
expected_output_cols_list=(
|
538
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
539
|
-
),
|
532
|
+
expected_output_cols = (
|
533
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
540
534
|
)
|
535
|
+
if isinstance(dataset, DataFrame):
|
536
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
537
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
538
|
+
)
|
539
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
540
|
+
drop_input_cols=self._drop_input_cols,
|
541
|
+
expected_output_cols_list=expected_output_cols,
|
542
|
+
example_output_pd_df=example_output_pd_df,
|
543
|
+
)
|
544
|
+
else:
|
545
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
546
|
+
drop_input_cols=self._drop_input_cols,
|
547
|
+
expected_output_cols_list=expected_output_cols,
|
548
|
+
)
|
541
549
|
self._sklearn_object = fitted_estimator
|
542
550
|
self._is_fitted = True
|
543
551
|
return output_result
|
@@ -562,6 +570,7 @@ class LinearDiscriminantAnalysis(BaseTransformer):
|
|
562
570
|
"""
|
563
571
|
self._infer_input_output_cols(dataset)
|
564
572
|
super()._check_dataset_type(dataset)
|
573
|
+
|
565
574
|
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
566
575
|
estimator=self._sklearn_object,
|
567
576
|
dataset=dataset,
|
@@ -618,12 +627,41 @@ class LinearDiscriminantAnalysis(BaseTransformer):
|
|
618
627
|
|
619
628
|
return rv
|
620
629
|
|
621
|
-
def
|
622
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
623
|
-
) -> List[str]:
|
630
|
+
def _align_expected_output(
|
631
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
632
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
633
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
634
|
+
and output dataframe with 1 line.
|
635
|
+
If the method is fit_predict, run 2 lines of data.
|
636
|
+
"""
|
624
637
|
# in case the inferred output column names dimension is different
|
625
638
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
626
|
-
|
639
|
+
|
640
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
641
|
+
# so change the minimum of number of rows to 2
|
642
|
+
num_examples = 2
|
643
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
644
|
+
project=_PROJECT,
|
645
|
+
subproject=_SUBPROJECT,
|
646
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
647
|
+
inspect.currentframe(), LinearDiscriminantAnalysis.__class__.__name__
|
648
|
+
),
|
649
|
+
api_calls=[Session.call],
|
650
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
651
|
+
)
|
652
|
+
if output_cols_prefix == "fit_predict_":
|
653
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
654
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
655
|
+
num_examples = self._sklearn_object.n_clusters
|
656
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
657
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
658
|
+
num_examples = self._sklearn_object.min_samples
|
659
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
660
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
661
|
+
num_examples = self._sklearn_object.n_neighbors
|
662
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
663
|
+
else:
|
664
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
627
665
|
|
628
666
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
629
667
|
# seen during the fit.
|
@@ -635,12 +673,14 @@ class LinearDiscriminantAnalysis(BaseTransformer):
|
|
635
673
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
636
674
|
if self.sample_weight_col:
|
637
675
|
output_df_columns_set -= set(self.sample_weight_col)
|
676
|
+
|
638
677
|
# if the dimension of inferred output column names is correct; use it
|
639
678
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
640
|
-
return expected_output_cols_list
|
679
|
+
return expected_output_cols_list, output_df_pd
|
641
680
|
# otherwise, use the sklearn estimator's output
|
642
681
|
else:
|
643
|
-
|
682
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
683
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
644
684
|
|
645
685
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
646
686
|
@telemetry.send_api_usage_telemetry(
|
@@ -688,7 +728,7 @@ class LinearDiscriminantAnalysis(BaseTransformer):
|
|
688
728
|
drop_input_cols=self._drop_input_cols,
|
689
729
|
expected_output_cols_type="float",
|
690
730
|
)
|
691
|
-
expected_output_cols = self.
|
731
|
+
expected_output_cols, _ = self._align_expected_output(
|
692
732
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
693
733
|
)
|
694
734
|
|
@@ -756,7 +796,7 @@ class LinearDiscriminantAnalysis(BaseTransformer):
|
|
756
796
|
drop_input_cols=self._drop_input_cols,
|
757
797
|
expected_output_cols_type="float",
|
758
798
|
)
|
759
|
-
expected_output_cols = self.
|
799
|
+
expected_output_cols, _ = self._align_expected_output(
|
760
800
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
761
801
|
)
|
762
802
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -821,7 +861,7 @@ class LinearDiscriminantAnalysis(BaseTransformer):
|
|
821
861
|
drop_input_cols=self._drop_input_cols,
|
822
862
|
expected_output_cols_type="float",
|
823
863
|
)
|
824
|
-
expected_output_cols = self.
|
864
|
+
expected_output_cols, _ = self._align_expected_output(
|
825
865
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
826
866
|
)
|
827
867
|
|
@@ -886,7 +926,7 @@ class LinearDiscriminantAnalysis(BaseTransformer):
|
|
886
926
|
drop_input_cols = self._drop_input_cols,
|
887
927
|
expected_output_cols_type="float",
|
888
928
|
)
|
889
|
-
expected_output_cols = self.
|
929
|
+
expected_output_cols, _ = self._align_expected_output(
|
890
930
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
891
931
|
)
|
892
932
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -492,12 +489,23 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
|
|
492
489
|
autogenerated=self._autogenerated,
|
493
490
|
subproject=_SUBPROJECT,
|
494
491
|
)
|
495
|
-
|
496
|
-
|
497
|
-
expected_output_cols_list=(
|
498
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
499
|
-
),
|
492
|
+
expected_output_cols = (
|
493
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
500
494
|
)
|
495
|
+
if isinstance(dataset, DataFrame):
|
496
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
497
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
498
|
+
)
|
499
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
500
|
+
drop_input_cols=self._drop_input_cols,
|
501
|
+
expected_output_cols_list=expected_output_cols,
|
502
|
+
example_output_pd_df=example_output_pd_df,
|
503
|
+
)
|
504
|
+
else:
|
505
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
506
|
+
drop_input_cols=self._drop_input_cols,
|
507
|
+
expected_output_cols_list=expected_output_cols,
|
508
|
+
)
|
501
509
|
self._sklearn_object = fitted_estimator
|
502
510
|
self._is_fitted = True
|
503
511
|
return output_result
|
@@ -520,6 +528,7 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
|
|
520
528
|
"""
|
521
529
|
self._infer_input_output_cols(dataset)
|
522
530
|
super()._check_dataset_type(dataset)
|
531
|
+
|
523
532
|
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
524
533
|
estimator=self._sklearn_object,
|
525
534
|
dataset=dataset,
|
@@ -576,12 +585,41 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
|
|
576
585
|
|
577
586
|
return rv
|
578
587
|
|
579
|
-
def
|
580
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
581
|
-
) -> List[str]:
|
588
|
+
def _align_expected_output(
|
589
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
590
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
591
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
592
|
+
and output dataframe with 1 line.
|
593
|
+
If the method is fit_predict, run 2 lines of data.
|
594
|
+
"""
|
582
595
|
# in case the inferred output column names dimension is different
|
583
596
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
584
|
-
|
597
|
+
|
598
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
599
|
+
# so change the minimum of number of rows to 2
|
600
|
+
num_examples = 2
|
601
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
602
|
+
project=_PROJECT,
|
603
|
+
subproject=_SUBPROJECT,
|
604
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
605
|
+
inspect.currentframe(), QuadraticDiscriminantAnalysis.__class__.__name__
|
606
|
+
),
|
607
|
+
api_calls=[Session.call],
|
608
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
609
|
+
)
|
610
|
+
if output_cols_prefix == "fit_predict_":
|
611
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
612
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
613
|
+
num_examples = self._sklearn_object.n_clusters
|
614
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
615
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
616
|
+
num_examples = self._sklearn_object.min_samples
|
617
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
618
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
619
|
+
num_examples = self._sklearn_object.n_neighbors
|
620
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
621
|
+
else:
|
622
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
585
623
|
|
586
624
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
587
625
|
# seen during the fit.
|
@@ -593,12 +631,14 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
|
|
593
631
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
594
632
|
if self.sample_weight_col:
|
595
633
|
output_df_columns_set -= set(self.sample_weight_col)
|
634
|
+
|
596
635
|
# if the dimension of inferred output column names is correct; use it
|
597
636
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
598
|
-
return expected_output_cols_list
|
637
|
+
return expected_output_cols_list, output_df_pd
|
599
638
|
# otherwise, use the sklearn estimator's output
|
600
639
|
else:
|
601
|
-
|
640
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
641
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
602
642
|
|
603
643
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
604
644
|
@telemetry.send_api_usage_telemetry(
|
@@ -646,7 +686,7 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
|
|
646
686
|
drop_input_cols=self._drop_input_cols,
|
647
687
|
expected_output_cols_type="float",
|
648
688
|
)
|
649
|
-
expected_output_cols = self.
|
689
|
+
expected_output_cols, _ = self._align_expected_output(
|
650
690
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
651
691
|
)
|
652
692
|
|
@@ -714,7 +754,7 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
|
|
714
754
|
drop_input_cols=self._drop_input_cols,
|
715
755
|
expected_output_cols_type="float",
|
716
756
|
)
|
717
|
-
expected_output_cols = self.
|
757
|
+
expected_output_cols, _ = self._align_expected_output(
|
718
758
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
719
759
|
)
|
720
760
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -779,7 +819,7 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
|
|
779
819
|
drop_input_cols=self._drop_input_cols,
|
780
820
|
expected_output_cols_type="float",
|
781
821
|
)
|
782
|
-
expected_output_cols = self.
|
822
|
+
expected_output_cols, _ = self._align_expected_output(
|
783
823
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
784
824
|
)
|
785
825
|
|
@@ -844,7 +884,7 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
|
|
844
884
|
drop_input_cols = self._drop_input_cols,
|
845
885
|
expected_output_cols_type="float",
|
846
886
|
)
|
847
|
-
expected_output_cols = self.
|
887
|
+
expected_output_cols, _ = self._align_expected_output(
|
848
888
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
849
889
|
)
|
850
890
|
|