snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -0
- snowflake/cortex/_classify_text.py +2 -2
- snowflake/cortex/_embed_text_1024.py +37 -0
- snowflake/cortex/_embed_text_768.py +37 -0
- snowflake/cortex/_extract_answer.py +2 -2
- snowflake/cortex/_sentiment.py +2 -2
- snowflake/cortex/_summarize.py +2 -2
- snowflake/cortex/_translate.py +2 -2
- snowflake/cortex/_util.py +4 -4
- snowflake/ml/_internal/env_utils.py +5 -5
- snowflake/ml/_internal/exceptions/error_codes.py +2 -0
- snowflake/ml/_internal/telemetry.py +142 -20
- snowflake/ml/_internal/utils/db_utils.py +50 -0
- snowflake/ml/_internal/utils/identifier.py +48 -11
- snowflake/ml/_internal/utils/service_logger.py +63 -0
- snowflake/ml/_internal/utils/snowflake_env.py +23 -13
- snowflake/ml/_internal/utils/sql_identifier.py +26 -2
- snowflake/ml/_internal/utils/table_manager.py +19 -1
- snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
- snowflake/ml/data/data_connector.py +33 -7
- snowflake/ml/data/ingestor_utils.py +20 -10
- snowflake/ml/data/torch_utils.py +68 -0
- snowflake/ml/dataset/dataset.py +1 -3
- snowflake/ml/feature_store/access_manager.py +3 -3
- snowflake/ml/feature_store/feature_store.py +60 -19
- snowflake/ml/feature_store/feature_view.py +84 -30
- snowflake/ml/fileset/embedded_stage_fs.py +1 -1
- snowflake/ml/fileset/fileset.py +1 -1
- snowflake/ml/fileset/sfcfs.py +9 -3
- snowflake/ml/fileset/stage_fs.py +2 -1
- snowflake/ml/lineage/lineage_node.py +7 -2
- snowflake/ml/model/__init__.py +1 -2
- snowflake/ml/model/_client/model/model_version_impl.py +96 -12
- snowflake/ml/model/_client/ops/model_ops.py +124 -6
- snowflake/ml/model/_client/ops/service_ops.py +309 -9
- snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
- snowflake/ml/model/_client/sql/_base.py +5 -0
- snowflake/ml/model/_client/sql/model.py +1 -0
- snowflake/ml/model/_client/sql/model_version.py +9 -5
- snowflake/ml/model/_client/sql/service.py +121 -20
- snowflake/ml/model/_model_composer/model_composer.py +11 -39
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
- snowflake/ml/model/_packager/model_env/model_env.py +4 -38
- snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
- snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
- snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
- snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
- snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
- snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
- snowflake/ml/model/_packager/model_packager.py +14 -8
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
- snowflake/ml/model/_signatures/utils.py +9 -0
- snowflake/ml/model/type_hints.py +12 -145
- snowflake/ml/modeling/_internal/constants.py +1 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
- snowflake/ml/modeling/_internal/model_specifications.py +2 -0
- snowflake/ml/modeling/_internal/model_trainer.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
- snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
- snowflake/ml/modeling/cluster/birch.py +61 -21
- snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
- snowflake/ml/modeling/cluster/dbscan.py +61 -21
- snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
- snowflake/ml/modeling/cluster/k_means.py +61 -21
- snowflake/ml/modeling/cluster/mean_shift.py +61 -21
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
- snowflake/ml/modeling/cluster/optics.py +61 -21
- snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
- snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
- snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
- snowflake/ml/modeling/compose/column_transformer.py +61 -21
- snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
- snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
- snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
- snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
- snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
- snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
- snowflake/ml/modeling/covariance/oas.py +61 -21
- snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
- snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
- snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
- snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
- snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
- snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
- snowflake/ml/modeling/decomposition/pca.py +61 -21
- snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
- snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
- snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
- snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
- snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
- snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
- snowflake/ml/modeling/impute/knn_imputer.py +61 -21
- snowflake/ml/modeling/impute/missing_indicator.py +61 -21
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
- snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
- snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
- snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/lars.py +61 -21
- snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
- snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
- snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/perceptron.py +61 -21
- snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ridge.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
- snowflake/ml/modeling/manifold/isomap.py +61 -21
- snowflake/ml/modeling/manifold/mds.py +61 -21
- snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
- snowflake/ml/modeling/manifold/tsne.py +61 -21
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
- snowflake/ml/modeling/metrics/ranking.py +0 -3
- snowflake/ml/modeling/metrics/regression.py +0 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
- snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
- snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
- snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
- snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
- snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
- snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
- snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +1 -13
- snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
- snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
- snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
- snowflake/ml/modeling/svm/linear_svc.py +61 -21
- snowflake/ml/modeling/svm/linear_svr.py +61 -21
- snowflake/ml/modeling/svm/nu_svc.py +61 -21
- snowflake/ml/modeling/svm/nu_svr.py +61 -21
- snowflake/ml/modeling/svm/svc.py +61 -21
- snowflake/ml/modeling/svm/svr.py +61 -21
- snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
- snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
- snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
- snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
- snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
- snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
- snowflake/ml/monitoring/_client/model_monitor.py +126 -0
- snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
- snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
- snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
- snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
- snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
- snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
- snowflake/ml/monitoring/entities/output_score_type.py +90 -0
- snowflake/ml/registry/_manager/model_manager.py +4 -0
- snowflake/ml/registry/registry.py +166 -8
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
- snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
- snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
- snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
- snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
- snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
- snowflake/ml/_internal/utils/session_token_manager.py +0 -46
- snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
- snowflake/ml/_internal/utils/uri.py +0 -77
- snowflake/ml/data/torch_dataset.py +0 -33
- snowflake/ml/model/_api.py +0 -568
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
- snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
- snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
- snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
- snowflake/ml/model/deploy_platforms.py +0 -6
- snowflake/ml/model/models/llm.py +0 -104
- snowflake/ml/monitoring/monitor.py +0 -203
- snowflake/ml/registry/_initial_schema.py +0 -142
- snowflake/ml/registry/_schema.py +0 -82
- snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
- snowflake/ml/registry/_schema_version_manager.py +0 -163
- snowflake/ml/registry/model_registry.py +0 -2048
- snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -495,12 +492,23 @@ class ComplementNB(BaseTransformer):
|
|
495
492
|
autogenerated=self._autogenerated,
|
496
493
|
subproject=_SUBPROJECT,
|
497
494
|
)
|
498
|
-
|
499
|
-
|
500
|
-
expected_output_cols_list=(
|
501
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
502
|
-
),
|
495
|
+
expected_output_cols = (
|
496
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
503
497
|
)
|
498
|
+
if isinstance(dataset, DataFrame):
|
499
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
500
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
501
|
+
)
|
502
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
503
|
+
drop_input_cols=self._drop_input_cols,
|
504
|
+
expected_output_cols_list=expected_output_cols,
|
505
|
+
example_output_pd_df=example_output_pd_df,
|
506
|
+
)
|
507
|
+
else:
|
508
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
509
|
+
drop_input_cols=self._drop_input_cols,
|
510
|
+
expected_output_cols_list=expected_output_cols,
|
511
|
+
)
|
504
512
|
self._sklearn_object = fitted_estimator
|
505
513
|
self._is_fitted = True
|
506
514
|
return output_result
|
@@ -523,6 +531,7 @@ class ComplementNB(BaseTransformer):
|
|
523
531
|
"""
|
524
532
|
self._infer_input_output_cols(dataset)
|
525
533
|
super()._check_dataset_type(dataset)
|
534
|
+
|
526
535
|
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
527
536
|
estimator=self._sklearn_object,
|
528
537
|
dataset=dataset,
|
@@ -579,12 +588,41 @@ class ComplementNB(BaseTransformer):
|
|
579
588
|
|
580
589
|
return rv
|
581
590
|
|
582
|
-
def
|
583
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
584
|
-
) -> List[str]:
|
591
|
+
def _align_expected_output(
|
592
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
593
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
594
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
595
|
+
and output dataframe with 1 line.
|
596
|
+
If the method is fit_predict, run 2 lines of data.
|
597
|
+
"""
|
585
598
|
# in case the inferred output column names dimension is different
|
586
599
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
587
|
-
|
600
|
+
|
601
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
602
|
+
# so change the minimum of number of rows to 2
|
603
|
+
num_examples = 2
|
604
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
605
|
+
project=_PROJECT,
|
606
|
+
subproject=_SUBPROJECT,
|
607
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
608
|
+
inspect.currentframe(), ComplementNB.__class__.__name__
|
609
|
+
),
|
610
|
+
api_calls=[Session.call],
|
611
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
612
|
+
)
|
613
|
+
if output_cols_prefix == "fit_predict_":
|
614
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
615
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
616
|
+
num_examples = self._sklearn_object.n_clusters
|
617
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
618
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
619
|
+
num_examples = self._sklearn_object.min_samples
|
620
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
621
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
622
|
+
num_examples = self._sklearn_object.n_neighbors
|
623
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
624
|
+
else:
|
625
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
588
626
|
|
589
627
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
590
628
|
# seen during the fit.
|
@@ -596,12 +634,14 @@ class ComplementNB(BaseTransformer):
|
|
596
634
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
597
635
|
if self.sample_weight_col:
|
598
636
|
output_df_columns_set -= set(self.sample_weight_col)
|
637
|
+
|
599
638
|
# if the dimension of inferred output column names is correct; use it
|
600
639
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
601
|
-
return expected_output_cols_list
|
640
|
+
return expected_output_cols_list, output_df_pd
|
602
641
|
# otherwise, use the sklearn estimator's output
|
603
642
|
else:
|
604
|
-
|
643
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
644
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
605
645
|
|
606
646
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
607
647
|
@telemetry.send_api_usage_telemetry(
|
@@ -649,7 +689,7 @@ class ComplementNB(BaseTransformer):
|
|
649
689
|
drop_input_cols=self._drop_input_cols,
|
650
690
|
expected_output_cols_type="float",
|
651
691
|
)
|
652
|
-
expected_output_cols = self.
|
692
|
+
expected_output_cols, _ = self._align_expected_output(
|
653
693
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
654
694
|
)
|
655
695
|
|
@@ -717,7 +757,7 @@ class ComplementNB(BaseTransformer):
|
|
717
757
|
drop_input_cols=self._drop_input_cols,
|
718
758
|
expected_output_cols_type="float",
|
719
759
|
)
|
720
|
-
expected_output_cols = self.
|
760
|
+
expected_output_cols, _ = self._align_expected_output(
|
721
761
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
722
762
|
)
|
723
763
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -780,7 +820,7 @@ class ComplementNB(BaseTransformer):
|
|
780
820
|
drop_input_cols=self._drop_input_cols,
|
781
821
|
expected_output_cols_type="float",
|
782
822
|
)
|
783
|
-
expected_output_cols = self.
|
823
|
+
expected_output_cols, _ = self._align_expected_output(
|
784
824
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
785
825
|
)
|
786
826
|
|
@@ -845,7 +885,7 @@ class ComplementNB(BaseTransformer):
|
|
845
885
|
drop_input_cols = self._drop_input_cols,
|
846
886
|
expected_output_cols_type="float",
|
847
887
|
)
|
848
|
-
expected_output_cols = self.
|
888
|
+
expected_output_cols, _ = self._align_expected_output(
|
849
889
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
850
890
|
)
|
851
891
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -476,12 +473,23 @@ class GaussianNB(BaseTransformer):
|
|
476
473
|
autogenerated=self._autogenerated,
|
477
474
|
subproject=_SUBPROJECT,
|
478
475
|
)
|
479
|
-
|
480
|
-
|
481
|
-
expected_output_cols_list=(
|
482
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
483
|
-
),
|
476
|
+
expected_output_cols = (
|
477
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
484
478
|
)
|
479
|
+
if isinstance(dataset, DataFrame):
|
480
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
481
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
482
|
+
)
|
483
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
484
|
+
drop_input_cols=self._drop_input_cols,
|
485
|
+
expected_output_cols_list=expected_output_cols,
|
486
|
+
example_output_pd_df=example_output_pd_df,
|
487
|
+
)
|
488
|
+
else:
|
489
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
490
|
+
drop_input_cols=self._drop_input_cols,
|
491
|
+
expected_output_cols_list=expected_output_cols,
|
492
|
+
)
|
485
493
|
self._sklearn_object = fitted_estimator
|
486
494
|
self._is_fitted = True
|
487
495
|
return output_result
|
@@ -504,6 +512,7 @@ class GaussianNB(BaseTransformer):
|
|
504
512
|
"""
|
505
513
|
self._infer_input_output_cols(dataset)
|
506
514
|
super()._check_dataset_type(dataset)
|
515
|
+
|
507
516
|
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
508
517
|
estimator=self._sklearn_object,
|
509
518
|
dataset=dataset,
|
@@ -560,12 +569,41 @@ class GaussianNB(BaseTransformer):
|
|
560
569
|
|
561
570
|
return rv
|
562
571
|
|
563
|
-
def
|
564
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
565
|
-
) -> List[str]:
|
572
|
+
def _align_expected_output(
|
573
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
574
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
575
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
576
|
+
and output dataframe with 1 line.
|
577
|
+
If the method is fit_predict, run 2 lines of data.
|
578
|
+
"""
|
566
579
|
# in case the inferred output column names dimension is different
|
567
580
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
568
|
-
|
581
|
+
|
582
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
583
|
+
# so change the minimum of number of rows to 2
|
584
|
+
num_examples = 2
|
585
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
586
|
+
project=_PROJECT,
|
587
|
+
subproject=_SUBPROJECT,
|
588
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
589
|
+
inspect.currentframe(), GaussianNB.__class__.__name__
|
590
|
+
),
|
591
|
+
api_calls=[Session.call],
|
592
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
593
|
+
)
|
594
|
+
if output_cols_prefix == "fit_predict_":
|
595
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
596
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
597
|
+
num_examples = self._sklearn_object.n_clusters
|
598
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
599
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
600
|
+
num_examples = self._sklearn_object.min_samples
|
601
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
602
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
603
|
+
num_examples = self._sklearn_object.n_neighbors
|
604
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
605
|
+
else:
|
606
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
569
607
|
|
570
608
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
571
609
|
# seen during the fit.
|
@@ -577,12 +615,14 @@ class GaussianNB(BaseTransformer):
|
|
577
615
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
578
616
|
if self.sample_weight_col:
|
579
617
|
output_df_columns_set -= set(self.sample_weight_col)
|
618
|
+
|
580
619
|
# if the dimension of inferred output column names is correct; use it
|
581
620
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
582
|
-
return expected_output_cols_list
|
621
|
+
return expected_output_cols_list, output_df_pd
|
583
622
|
# otherwise, use the sklearn estimator's output
|
584
623
|
else:
|
585
|
-
|
624
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
625
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
586
626
|
|
587
627
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
588
628
|
@telemetry.send_api_usage_telemetry(
|
@@ -630,7 +670,7 @@ class GaussianNB(BaseTransformer):
|
|
630
670
|
drop_input_cols=self._drop_input_cols,
|
631
671
|
expected_output_cols_type="float",
|
632
672
|
)
|
633
|
-
expected_output_cols = self.
|
673
|
+
expected_output_cols, _ = self._align_expected_output(
|
634
674
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
635
675
|
)
|
636
676
|
|
@@ -698,7 +738,7 @@ class GaussianNB(BaseTransformer):
|
|
698
738
|
drop_input_cols=self._drop_input_cols,
|
699
739
|
expected_output_cols_type="float",
|
700
740
|
)
|
701
|
-
expected_output_cols = self.
|
741
|
+
expected_output_cols, _ = self._align_expected_output(
|
702
742
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
703
743
|
)
|
704
744
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -761,7 +801,7 @@ class GaussianNB(BaseTransformer):
|
|
761
801
|
drop_input_cols=self._drop_input_cols,
|
762
802
|
expected_output_cols_type="float",
|
763
803
|
)
|
764
|
-
expected_output_cols = self.
|
804
|
+
expected_output_cols, _ = self._align_expected_output(
|
765
805
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
766
806
|
)
|
767
807
|
|
@@ -826,7 +866,7 @@ class GaussianNB(BaseTransformer):
|
|
826
866
|
drop_input_cols = self._drop_input_cols,
|
827
867
|
expected_output_cols_type="float",
|
828
868
|
)
|
829
|
-
expected_output_cols = self.
|
869
|
+
expected_output_cols, _ = self._align_expected_output(
|
830
870
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
831
871
|
)
|
832
872
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -489,12 +486,23 @@ class MultinomialNB(BaseTransformer):
|
|
489
486
|
autogenerated=self._autogenerated,
|
490
487
|
subproject=_SUBPROJECT,
|
491
488
|
)
|
492
|
-
|
493
|
-
|
494
|
-
expected_output_cols_list=(
|
495
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
496
|
-
),
|
489
|
+
expected_output_cols = (
|
490
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
497
491
|
)
|
492
|
+
if isinstance(dataset, DataFrame):
|
493
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
494
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
495
|
+
)
|
496
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
497
|
+
drop_input_cols=self._drop_input_cols,
|
498
|
+
expected_output_cols_list=expected_output_cols,
|
499
|
+
example_output_pd_df=example_output_pd_df,
|
500
|
+
)
|
501
|
+
else:
|
502
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
503
|
+
drop_input_cols=self._drop_input_cols,
|
504
|
+
expected_output_cols_list=expected_output_cols,
|
505
|
+
)
|
498
506
|
self._sklearn_object = fitted_estimator
|
499
507
|
self._is_fitted = True
|
500
508
|
return output_result
|
@@ -517,6 +525,7 @@ class MultinomialNB(BaseTransformer):
|
|
517
525
|
"""
|
518
526
|
self._infer_input_output_cols(dataset)
|
519
527
|
super()._check_dataset_type(dataset)
|
528
|
+
|
520
529
|
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
521
530
|
estimator=self._sklearn_object,
|
522
531
|
dataset=dataset,
|
@@ -573,12 +582,41 @@ class MultinomialNB(BaseTransformer):
|
|
573
582
|
|
574
583
|
return rv
|
575
584
|
|
576
|
-
def
|
577
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
578
|
-
) -> List[str]:
|
585
|
+
def _align_expected_output(
|
586
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
587
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
588
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
589
|
+
and output dataframe with 1 line.
|
590
|
+
If the method is fit_predict, run 2 lines of data.
|
591
|
+
"""
|
579
592
|
# in case the inferred output column names dimension is different
|
580
593
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
581
|
-
|
594
|
+
|
595
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
596
|
+
# so change the minimum of number of rows to 2
|
597
|
+
num_examples = 2
|
598
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
599
|
+
project=_PROJECT,
|
600
|
+
subproject=_SUBPROJECT,
|
601
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
602
|
+
inspect.currentframe(), MultinomialNB.__class__.__name__
|
603
|
+
),
|
604
|
+
api_calls=[Session.call],
|
605
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
606
|
+
)
|
607
|
+
if output_cols_prefix == "fit_predict_":
|
608
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
609
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
610
|
+
num_examples = self._sklearn_object.n_clusters
|
611
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
612
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
613
|
+
num_examples = self._sklearn_object.min_samples
|
614
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
615
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
616
|
+
num_examples = self._sklearn_object.n_neighbors
|
617
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
618
|
+
else:
|
619
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
582
620
|
|
583
621
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
584
622
|
# seen during the fit.
|
@@ -590,12 +628,14 @@ class MultinomialNB(BaseTransformer):
|
|
590
628
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
591
629
|
if self.sample_weight_col:
|
592
630
|
output_df_columns_set -= set(self.sample_weight_col)
|
631
|
+
|
593
632
|
# if the dimension of inferred output column names is correct; use it
|
594
633
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
595
|
-
return expected_output_cols_list
|
634
|
+
return expected_output_cols_list, output_df_pd
|
596
635
|
# otherwise, use the sklearn estimator's output
|
597
636
|
else:
|
598
|
-
|
637
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
638
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
599
639
|
|
600
640
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
601
641
|
@telemetry.send_api_usage_telemetry(
|
@@ -643,7 +683,7 @@ class MultinomialNB(BaseTransformer):
|
|
643
683
|
drop_input_cols=self._drop_input_cols,
|
644
684
|
expected_output_cols_type="float",
|
645
685
|
)
|
646
|
-
expected_output_cols = self.
|
686
|
+
expected_output_cols, _ = self._align_expected_output(
|
647
687
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
648
688
|
)
|
649
689
|
|
@@ -711,7 +751,7 @@ class MultinomialNB(BaseTransformer):
|
|
711
751
|
drop_input_cols=self._drop_input_cols,
|
712
752
|
expected_output_cols_type="float",
|
713
753
|
)
|
714
|
-
expected_output_cols = self.
|
754
|
+
expected_output_cols, _ = self._align_expected_output(
|
715
755
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
716
756
|
)
|
717
757
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -774,7 +814,7 @@ class MultinomialNB(BaseTransformer):
|
|
774
814
|
drop_input_cols=self._drop_input_cols,
|
775
815
|
expected_output_cols_type="float",
|
776
816
|
)
|
777
|
-
expected_output_cols = self.
|
817
|
+
expected_output_cols, _ = self._align_expected_output(
|
778
818
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
779
819
|
)
|
780
820
|
|
@@ -839,7 +879,7 @@ class MultinomialNB(BaseTransformer):
|
|
839
879
|
drop_input_cols = self._drop_input_cols,
|
840
880
|
expected_output_cols_type="float",
|
841
881
|
)
|
842
|
-
expected_output_cols = self.
|
882
|
+
expected_output_cols, _ = self._align_expected_output(
|
843
883
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
844
884
|
)
|
845
885
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -546,12 +543,23 @@ class KNeighborsClassifier(BaseTransformer):
|
|
546
543
|
autogenerated=self._autogenerated,
|
547
544
|
subproject=_SUBPROJECT,
|
548
545
|
)
|
549
|
-
|
550
|
-
|
551
|
-
expected_output_cols_list=(
|
552
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
553
|
-
),
|
546
|
+
expected_output_cols = (
|
547
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
554
548
|
)
|
549
|
+
if isinstance(dataset, DataFrame):
|
550
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
551
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
552
|
+
)
|
553
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
554
|
+
drop_input_cols=self._drop_input_cols,
|
555
|
+
expected_output_cols_list=expected_output_cols,
|
556
|
+
example_output_pd_df=example_output_pd_df,
|
557
|
+
)
|
558
|
+
else:
|
559
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
560
|
+
drop_input_cols=self._drop_input_cols,
|
561
|
+
expected_output_cols_list=expected_output_cols,
|
562
|
+
)
|
555
563
|
self._sklearn_object = fitted_estimator
|
556
564
|
self._is_fitted = True
|
557
565
|
return output_result
|
@@ -574,6 +582,7 @@ class KNeighborsClassifier(BaseTransformer):
|
|
574
582
|
"""
|
575
583
|
self._infer_input_output_cols(dataset)
|
576
584
|
super()._check_dataset_type(dataset)
|
585
|
+
|
577
586
|
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
578
587
|
estimator=self._sklearn_object,
|
579
588
|
dataset=dataset,
|
@@ -630,12 +639,41 @@ class KNeighborsClassifier(BaseTransformer):
|
|
630
639
|
|
631
640
|
return rv
|
632
641
|
|
633
|
-
def
|
634
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
635
|
-
) -> List[str]:
|
642
|
+
def _align_expected_output(
|
643
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
644
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
645
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
646
|
+
and output dataframe with 1 line.
|
647
|
+
If the method is fit_predict, run 2 lines of data.
|
648
|
+
"""
|
636
649
|
# in case the inferred output column names dimension is different
|
637
650
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
638
|
-
|
651
|
+
|
652
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
653
|
+
# so change the minimum of number of rows to 2
|
654
|
+
num_examples = 2
|
655
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
656
|
+
project=_PROJECT,
|
657
|
+
subproject=_SUBPROJECT,
|
658
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
659
|
+
inspect.currentframe(), KNeighborsClassifier.__class__.__name__
|
660
|
+
),
|
661
|
+
api_calls=[Session.call],
|
662
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
663
|
+
)
|
664
|
+
if output_cols_prefix == "fit_predict_":
|
665
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
666
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
667
|
+
num_examples = self._sklearn_object.n_clusters
|
668
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
669
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
670
|
+
num_examples = self._sklearn_object.min_samples
|
671
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
672
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
673
|
+
num_examples = self._sklearn_object.n_neighbors
|
674
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
675
|
+
else:
|
676
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
639
677
|
|
640
678
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
641
679
|
# seen during the fit.
|
@@ -647,12 +685,14 @@ class KNeighborsClassifier(BaseTransformer):
|
|
647
685
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
648
686
|
if self.sample_weight_col:
|
649
687
|
output_df_columns_set -= set(self.sample_weight_col)
|
688
|
+
|
650
689
|
# if the dimension of inferred output column names is correct; use it
|
651
690
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
652
|
-
return expected_output_cols_list
|
691
|
+
return expected_output_cols_list, output_df_pd
|
653
692
|
# otherwise, use the sklearn estimator's output
|
654
693
|
else:
|
655
|
-
|
694
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
695
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
656
696
|
|
657
697
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
658
698
|
@telemetry.send_api_usage_telemetry(
|
@@ -700,7 +740,7 @@ class KNeighborsClassifier(BaseTransformer):
|
|
700
740
|
drop_input_cols=self._drop_input_cols,
|
701
741
|
expected_output_cols_type="float",
|
702
742
|
)
|
703
|
-
expected_output_cols = self.
|
743
|
+
expected_output_cols, _ = self._align_expected_output(
|
704
744
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
705
745
|
)
|
706
746
|
|
@@ -768,7 +808,7 @@ class KNeighborsClassifier(BaseTransformer):
|
|
768
808
|
drop_input_cols=self._drop_input_cols,
|
769
809
|
expected_output_cols_type="float",
|
770
810
|
)
|
771
|
-
expected_output_cols = self.
|
811
|
+
expected_output_cols, _ = self._align_expected_output(
|
772
812
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
773
813
|
)
|
774
814
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -831,7 +871,7 @@ class KNeighborsClassifier(BaseTransformer):
|
|
831
871
|
drop_input_cols=self._drop_input_cols,
|
832
872
|
expected_output_cols_type="float",
|
833
873
|
)
|
834
|
-
expected_output_cols = self.
|
874
|
+
expected_output_cols, _ = self._align_expected_output(
|
835
875
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
836
876
|
)
|
837
877
|
|
@@ -896,7 +936,7 @@ class KNeighborsClassifier(BaseTransformer):
|
|
896
936
|
drop_input_cols = self._drop_input_cols,
|
897
937
|
expected_output_cols_type="float",
|
898
938
|
)
|
899
|
-
expected_output_cols = self.
|
939
|
+
expected_output_cols, _ = self._align_expected_output(
|
900
940
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
901
941
|
)
|
902
942
|
|