snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -0
- snowflake/cortex/_classify_text.py +2 -2
- snowflake/cortex/_embed_text_1024.py +37 -0
- snowflake/cortex/_embed_text_768.py +37 -0
- snowflake/cortex/_extract_answer.py +2 -2
- snowflake/cortex/_sentiment.py +2 -2
- snowflake/cortex/_summarize.py +2 -2
- snowflake/cortex/_translate.py +2 -2
- snowflake/cortex/_util.py +4 -4
- snowflake/ml/_internal/env_utils.py +5 -5
- snowflake/ml/_internal/exceptions/error_codes.py +2 -0
- snowflake/ml/_internal/telemetry.py +142 -20
- snowflake/ml/_internal/utils/db_utils.py +50 -0
- snowflake/ml/_internal/utils/identifier.py +48 -11
- snowflake/ml/_internal/utils/service_logger.py +63 -0
- snowflake/ml/_internal/utils/snowflake_env.py +23 -13
- snowflake/ml/_internal/utils/sql_identifier.py +26 -2
- snowflake/ml/_internal/utils/table_manager.py +19 -1
- snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
- snowflake/ml/data/data_connector.py +33 -7
- snowflake/ml/data/ingestor_utils.py +20 -10
- snowflake/ml/data/torch_utils.py +68 -0
- snowflake/ml/dataset/dataset.py +1 -3
- snowflake/ml/feature_store/access_manager.py +3 -3
- snowflake/ml/feature_store/feature_store.py +60 -19
- snowflake/ml/feature_store/feature_view.py +84 -30
- snowflake/ml/fileset/embedded_stage_fs.py +1 -1
- snowflake/ml/fileset/fileset.py +1 -1
- snowflake/ml/fileset/sfcfs.py +9 -3
- snowflake/ml/fileset/stage_fs.py +2 -1
- snowflake/ml/lineage/lineage_node.py +7 -2
- snowflake/ml/model/__init__.py +1 -2
- snowflake/ml/model/_client/model/model_version_impl.py +96 -12
- snowflake/ml/model/_client/ops/model_ops.py +124 -6
- snowflake/ml/model/_client/ops/service_ops.py +309 -9
- snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
- snowflake/ml/model/_client/sql/_base.py +5 -0
- snowflake/ml/model/_client/sql/model.py +1 -0
- snowflake/ml/model/_client/sql/model_version.py +9 -5
- snowflake/ml/model/_client/sql/service.py +121 -20
- snowflake/ml/model/_model_composer/model_composer.py +11 -39
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
- snowflake/ml/model/_packager/model_env/model_env.py +4 -38
- snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
- snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
- snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
- snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
- snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
- snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
- snowflake/ml/model/_packager/model_packager.py +14 -8
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
- snowflake/ml/model/_signatures/utils.py +9 -0
- snowflake/ml/model/type_hints.py +12 -145
- snowflake/ml/modeling/_internal/constants.py +1 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
- snowflake/ml/modeling/_internal/model_specifications.py +2 -0
- snowflake/ml/modeling/_internal/model_trainer.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
- snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
- snowflake/ml/modeling/cluster/birch.py +61 -21
- snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
- snowflake/ml/modeling/cluster/dbscan.py +61 -21
- snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
- snowflake/ml/modeling/cluster/k_means.py +61 -21
- snowflake/ml/modeling/cluster/mean_shift.py +61 -21
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
- snowflake/ml/modeling/cluster/optics.py +61 -21
- snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
- snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
- snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
- snowflake/ml/modeling/compose/column_transformer.py +61 -21
- snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
- snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
- snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
- snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
- snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
- snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
- snowflake/ml/modeling/covariance/oas.py +61 -21
- snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
- snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
- snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
- snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
- snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
- snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
- snowflake/ml/modeling/decomposition/pca.py +61 -21
- snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
- snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
- snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
- snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
- snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
- snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
- snowflake/ml/modeling/impute/knn_imputer.py +61 -21
- snowflake/ml/modeling/impute/missing_indicator.py +61 -21
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
- snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
- snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
- snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/lars.py +61 -21
- snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
- snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
- snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/perceptron.py +61 -21
- snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ridge.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
- snowflake/ml/modeling/manifold/isomap.py +61 -21
- snowflake/ml/modeling/manifold/mds.py +61 -21
- snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
- snowflake/ml/modeling/manifold/tsne.py +61 -21
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
- snowflake/ml/modeling/metrics/ranking.py +0 -3
- snowflake/ml/modeling/metrics/regression.py +0 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
- snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
- snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
- snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
- snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
- snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
- snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
- snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +1 -13
- snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
- snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
- snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
- snowflake/ml/modeling/svm/linear_svc.py +61 -21
- snowflake/ml/modeling/svm/linear_svr.py +61 -21
- snowflake/ml/modeling/svm/nu_svc.py +61 -21
- snowflake/ml/modeling/svm/nu_svr.py +61 -21
- snowflake/ml/modeling/svm/svc.py +61 -21
- snowflake/ml/modeling/svm/svr.py +61 -21
- snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
- snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
- snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
- snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
- snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
- snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
- snowflake/ml/monitoring/_client/model_monitor.py +126 -0
- snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
- snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
- snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
- snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
- snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
- snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
- snowflake/ml/monitoring/entities/output_score_type.py +90 -0
- snowflake/ml/registry/_manager/model_manager.py +4 -0
- snowflake/ml/registry/registry.py +166 -8
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
- snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
- snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
- snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
- snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
- snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
- snowflake/ml/_internal/utils/session_token_manager.py +0 -46
- snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
- snowflake/ml/_internal/utils/uri.py +0 -77
- snowflake/ml/data/torch_dataset.py +0 -33
- snowflake/ml/model/_api.py +0 -568
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
- snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
- snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
- snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
- snowflake/ml/model/deploy_platforms.py +0 -6
- snowflake/ml/model/models/llm.py +0 -104
- snowflake/ml/monitoring/monitor.py +0 -203
- snowflake/ml/registry/_initial_schema.py +0 -142
- snowflake/ml/registry/_schema.py +0 -82
- snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
- snowflake/ml/registry/_schema_version_manager.py +0 -163
- snowflake/ml/registry/model_registry.py +0 -2048
- snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -492,12 +489,23 @@ class OneVsRestClassifier(BaseTransformer):
|
|
492
489
|
autogenerated=self._autogenerated,
|
493
490
|
subproject=_SUBPROJECT,
|
494
491
|
)
|
495
|
-
|
496
|
-
|
497
|
-
expected_output_cols_list=(
|
498
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
499
|
-
),
|
492
|
+
expected_output_cols = (
|
493
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
500
494
|
)
|
495
|
+
if isinstance(dataset, DataFrame):
|
496
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
497
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
498
|
+
)
|
499
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
500
|
+
drop_input_cols=self._drop_input_cols,
|
501
|
+
expected_output_cols_list=expected_output_cols,
|
502
|
+
example_output_pd_df=example_output_pd_df,
|
503
|
+
)
|
504
|
+
else:
|
505
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
506
|
+
drop_input_cols=self._drop_input_cols,
|
507
|
+
expected_output_cols_list=expected_output_cols,
|
508
|
+
)
|
501
509
|
self._sklearn_object = fitted_estimator
|
502
510
|
self._is_fitted = True
|
503
511
|
return output_result
|
@@ -520,6 +528,7 @@ class OneVsRestClassifier(BaseTransformer):
|
|
520
528
|
"""
|
521
529
|
self._infer_input_output_cols(dataset)
|
522
530
|
super()._check_dataset_type(dataset)
|
531
|
+
|
523
532
|
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
524
533
|
estimator=self._sklearn_object,
|
525
534
|
dataset=dataset,
|
@@ -576,12 +585,41 @@ class OneVsRestClassifier(BaseTransformer):
|
|
576
585
|
|
577
586
|
return rv
|
578
587
|
|
579
|
-
def
|
580
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
581
|
-
) -> List[str]:
|
588
|
+
def _align_expected_output(
|
589
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
590
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
591
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
592
|
+
and output dataframe with 1 line.
|
593
|
+
If the method is fit_predict, run 2 lines of data.
|
594
|
+
"""
|
582
595
|
# in case the inferred output column names dimension is different
|
583
596
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
584
|
-
|
597
|
+
|
598
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
599
|
+
# so change the minimum of number of rows to 2
|
600
|
+
num_examples = 2
|
601
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
602
|
+
project=_PROJECT,
|
603
|
+
subproject=_SUBPROJECT,
|
604
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
605
|
+
inspect.currentframe(), OneVsRestClassifier.__class__.__name__
|
606
|
+
),
|
607
|
+
api_calls=[Session.call],
|
608
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
609
|
+
)
|
610
|
+
if output_cols_prefix == "fit_predict_":
|
611
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
612
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
613
|
+
num_examples = self._sklearn_object.n_clusters
|
614
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
615
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
616
|
+
num_examples = self._sklearn_object.min_samples
|
617
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
618
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
619
|
+
num_examples = self._sklearn_object.n_neighbors
|
620
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
621
|
+
else:
|
622
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
585
623
|
|
586
624
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
587
625
|
# seen during the fit.
|
@@ -593,12 +631,14 @@ class OneVsRestClassifier(BaseTransformer):
|
|
593
631
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
594
632
|
if self.sample_weight_col:
|
595
633
|
output_df_columns_set -= set(self.sample_weight_col)
|
634
|
+
|
596
635
|
# if the dimension of inferred output column names is correct; use it
|
597
636
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
598
|
-
return expected_output_cols_list
|
637
|
+
return expected_output_cols_list, output_df_pd
|
599
638
|
# otherwise, use the sklearn estimator's output
|
600
639
|
else:
|
601
|
-
|
640
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
641
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
602
642
|
|
603
643
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
604
644
|
@telemetry.send_api_usage_telemetry(
|
@@ -646,7 +686,7 @@ class OneVsRestClassifier(BaseTransformer):
|
|
646
686
|
drop_input_cols=self._drop_input_cols,
|
647
687
|
expected_output_cols_type="float",
|
648
688
|
)
|
649
|
-
expected_output_cols = self.
|
689
|
+
expected_output_cols, _ = self._align_expected_output(
|
650
690
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
651
691
|
)
|
652
692
|
|
@@ -714,7 +754,7 @@ class OneVsRestClassifier(BaseTransformer):
|
|
714
754
|
drop_input_cols=self._drop_input_cols,
|
715
755
|
expected_output_cols_type="float",
|
716
756
|
)
|
717
|
-
expected_output_cols = self.
|
757
|
+
expected_output_cols, _ = self._align_expected_output(
|
718
758
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
719
759
|
)
|
720
760
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -779,7 +819,7 @@ class OneVsRestClassifier(BaseTransformer):
|
|
779
819
|
drop_input_cols=self._drop_input_cols,
|
780
820
|
expected_output_cols_type="float",
|
781
821
|
)
|
782
|
-
expected_output_cols = self.
|
822
|
+
expected_output_cols, _ = self._align_expected_output(
|
783
823
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
784
824
|
)
|
785
825
|
|
@@ -844,7 +884,7 @@ class OneVsRestClassifier(BaseTransformer):
|
|
844
884
|
drop_input_cols = self._drop_input_cols,
|
845
885
|
expected_output_cols_type="float",
|
846
886
|
)
|
847
|
-
expected_output_cols = self.
|
887
|
+
expected_output_cols, _ = self._align_expected_output(
|
848
888
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
849
889
|
)
|
850
890
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -495,12 +492,23 @@ class OutputCodeClassifier(BaseTransformer):
|
|
495
492
|
autogenerated=self._autogenerated,
|
496
493
|
subproject=_SUBPROJECT,
|
497
494
|
)
|
498
|
-
|
499
|
-
|
500
|
-
expected_output_cols_list=(
|
501
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
502
|
-
),
|
495
|
+
expected_output_cols = (
|
496
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
503
497
|
)
|
498
|
+
if isinstance(dataset, DataFrame):
|
499
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
500
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
501
|
+
)
|
502
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
503
|
+
drop_input_cols=self._drop_input_cols,
|
504
|
+
expected_output_cols_list=expected_output_cols,
|
505
|
+
example_output_pd_df=example_output_pd_df,
|
506
|
+
)
|
507
|
+
else:
|
508
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
509
|
+
drop_input_cols=self._drop_input_cols,
|
510
|
+
expected_output_cols_list=expected_output_cols,
|
511
|
+
)
|
504
512
|
self._sklearn_object = fitted_estimator
|
505
513
|
self._is_fitted = True
|
506
514
|
return output_result
|
@@ -523,6 +531,7 @@ class OutputCodeClassifier(BaseTransformer):
|
|
523
531
|
"""
|
524
532
|
self._infer_input_output_cols(dataset)
|
525
533
|
super()._check_dataset_type(dataset)
|
534
|
+
|
526
535
|
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
527
536
|
estimator=self._sklearn_object,
|
528
537
|
dataset=dataset,
|
@@ -579,12 +588,41 @@ class OutputCodeClassifier(BaseTransformer):
|
|
579
588
|
|
580
589
|
return rv
|
581
590
|
|
582
|
-
def
|
583
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
584
|
-
) -> List[str]:
|
591
|
+
def _align_expected_output(
|
592
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
593
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
594
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
595
|
+
and output dataframe with 1 line.
|
596
|
+
If the method is fit_predict, run 2 lines of data.
|
597
|
+
"""
|
585
598
|
# in case the inferred output column names dimension is different
|
586
599
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
587
|
-
|
600
|
+
|
601
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
602
|
+
# so change the minimum of number of rows to 2
|
603
|
+
num_examples = 2
|
604
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
605
|
+
project=_PROJECT,
|
606
|
+
subproject=_SUBPROJECT,
|
607
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
608
|
+
inspect.currentframe(), OutputCodeClassifier.__class__.__name__
|
609
|
+
),
|
610
|
+
api_calls=[Session.call],
|
611
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
612
|
+
)
|
613
|
+
if output_cols_prefix == "fit_predict_":
|
614
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
615
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
616
|
+
num_examples = self._sklearn_object.n_clusters
|
617
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
618
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
619
|
+
num_examples = self._sklearn_object.min_samples
|
620
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
621
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
622
|
+
num_examples = self._sklearn_object.n_neighbors
|
623
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
624
|
+
else:
|
625
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
588
626
|
|
589
627
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
590
628
|
# seen during the fit.
|
@@ -596,12 +634,14 @@ class OutputCodeClassifier(BaseTransformer):
|
|
596
634
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
597
635
|
if self.sample_weight_col:
|
598
636
|
output_df_columns_set -= set(self.sample_weight_col)
|
637
|
+
|
599
638
|
# if the dimension of inferred output column names is correct; use it
|
600
639
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
601
|
-
return expected_output_cols_list
|
640
|
+
return expected_output_cols_list, output_df_pd
|
602
641
|
# otherwise, use the sklearn estimator's output
|
603
642
|
else:
|
604
|
-
|
643
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
644
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
605
645
|
|
606
646
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
607
647
|
@telemetry.send_api_usage_telemetry(
|
@@ -647,7 +687,7 @@ class OutputCodeClassifier(BaseTransformer):
|
|
647
687
|
drop_input_cols=self._drop_input_cols,
|
648
688
|
expected_output_cols_type="float",
|
649
689
|
)
|
650
|
-
expected_output_cols = self.
|
690
|
+
expected_output_cols, _ = self._align_expected_output(
|
651
691
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
652
692
|
)
|
653
693
|
|
@@ -713,7 +753,7 @@ class OutputCodeClassifier(BaseTransformer):
|
|
713
753
|
drop_input_cols=self._drop_input_cols,
|
714
754
|
expected_output_cols_type="float",
|
715
755
|
)
|
716
|
-
expected_output_cols = self.
|
756
|
+
expected_output_cols, _ = self._align_expected_output(
|
717
757
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
718
758
|
)
|
719
759
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -776,7 +816,7 @@ class OutputCodeClassifier(BaseTransformer):
|
|
776
816
|
drop_input_cols=self._drop_input_cols,
|
777
817
|
expected_output_cols_type="float",
|
778
818
|
)
|
779
|
-
expected_output_cols = self.
|
819
|
+
expected_output_cols, _ = self._align_expected_output(
|
780
820
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
781
821
|
)
|
782
822
|
|
@@ -841,7 +881,7 @@ class OutputCodeClassifier(BaseTransformer):
|
|
841
881
|
drop_input_cols = self._drop_input_cols,
|
842
882
|
expected_output_cols_type="float",
|
843
883
|
)
|
844
|
-
expected_output_cols = self.
|
884
|
+
expected_output_cols, _ = self._align_expected_output(
|
845
885
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
846
886
|
)
|
847
887
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -495,12 +492,23 @@ class BernoulliNB(BaseTransformer):
|
|
495
492
|
autogenerated=self._autogenerated,
|
496
493
|
subproject=_SUBPROJECT,
|
497
494
|
)
|
498
|
-
|
499
|
-
|
500
|
-
expected_output_cols_list=(
|
501
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
502
|
-
),
|
495
|
+
expected_output_cols = (
|
496
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
503
497
|
)
|
498
|
+
if isinstance(dataset, DataFrame):
|
499
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
500
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
501
|
+
)
|
502
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
503
|
+
drop_input_cols=self._drop_input_cols,
|
504
|
+
expected_output_cols_list=expected_output_cols,
|
505
|
+
example_output_pd_df=example_output_pd_df,
|
506
|
+
)
|
507
|
+
else:
|
508
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
509
|
+
drop_input_cols=self._drop_input_cols,
|
510
|
+
expected_output_cols_list=expected_output_cols,
|
511
|
+
)
|
504
512
|
self._sklearn_object = fitted_estimator
|
505
513
|
self._is_fitted = True
|
506
514
|
return output_result
|
@@ -523,6 +531,7 @@ class BernoulliNB(BaseTransformer):
|
|
523
531
|
"""
|
524
532
|
self._infer_input_output_cols(dataset)
|
525
533
|
super()._check_dataset_type(dataset)
|
534
|
+
|
526
535
|
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
527
536
|
estimator=self._sklearn_object,
|
528
537
|
dataset=dataset,
|
@@ -579,12 +588,41 @@ class BernoulliNB(BaseTransformer):
|
|
579
588
|
|
580
589
|
return rv
|
581
590
|
|
582
|
-
def
|
583
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
584
|
-
) -> List[str]:
|
591
|
+
def _align_expected_output(
|
592
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
593
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
594
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
595
|
+
and output dataframe with 1 line.
|
596
|
+
If the method is fit_predict, run 2 lines of data.
|
597
|
+
"""
|
585
598
|
# in case the inferred output column names dimension is different
|
586
599
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
587
|
-
|
600
|
+
|
601
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
602
|
+
# so change the minimum of number of rows to 2
|
603
|
+
num_examples = 2
|
604
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
605
|
+
project=_PROJECT,
|
606
|
+
subproject=_SUBPROJECT,
|
607
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
608
|
+
inspect.currentframe(), BernoulliNB.__class__.__name__
|
609
|
+
),
|
610
|
+
api_calls=[Session.call],
|
611
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
612
|
+
)
|
613
|
+
if output_cols_prefix == "fit_predict_":
|
614
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
615
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
616
|
+
num_examples = self._sklearn_object.n_clusters
|
617
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
618
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
619
|
+
num_examples = self._sklearn_object.min_samples
|
620
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
621
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
622
|
+
num_examples = self._sklearn_object.n_neighbors
|
623
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
624
|
+
else:
|
625
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
588
626
|
|
589
627
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
590
628
|
# seen during the fit.
|
@@ -596,12 +634,14 @@ class BernoulliNB(BaseTransformer):
|
|
596
634
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
597
635
|
if self.sample_weight_col:
|
598
636
|
output_df_columns_set -= set(self.sample_weight_col)
|
637
|
+
|
599
638
|
# if the dimension of inferred output column names is correct; use it
|
600
639
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
601
|
-
return expected_output_cols_list
|
640
|
+
return expected_output_cols_list, output_df_pd
|
602
641
|
# otherwise, use the sklearn estimator's output
|
603
642
|
else:
|
604
|
-
|
643
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
644
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
605
645
|
|
606
646
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
607
647
|
@telemetry.send_api_usage_telemetry(
|
@@ -649,7 +689,7 @@ class BernoulliNB(BaseTransformer):
|
|
649
689
|
drop_input_cols=self._drop_input_cols,
|
650
690
|
expected_output_cols_type="float",
|
651
691
|
)
|
652
|
-
expected_output_cols = self.
|
692
|
+
expected_output_cols, _ = self._align_expected_output(
|
653
693
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
654
694
|
)
|
655
695
|
|
@@ -717,7 +757,7 @@ class BernoulliNB(BaseTransformer):
|
|
717
757
|
drop_input_cols=self._drop_input_cols,
|
718
758
|
expected_output_cols_type="float",
|
719
759
|
)
|
720
|
-
expected_output_cols = self.
|
760
|
+
expected_output_cols, _ = self._align_expected_output(
|
721
761
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
722
762
|
)
|
723
763
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -780,7 +820,7 @@ class BernoulliNB(BaseTransformer):
|
|
780
820
|
drop_input_cols=self._drop_input_cols,
|
781
821
|
expected_output_cols_type="float",
|
782
822
|
)
|
783
|
-
expected_output_cols = self.
|
823
|
+
expected_output_cols, _ = self._align_expected_output(
|
784
824
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
785
825
|
)
|
786
826
|
|
@@ -845,7 +885,7 @@ class BernoulliNB(BaseTransformer):
|
|
845
885
|
drop_input_cols = self._drop_input_cols,
|
846
886
|
expected_output_cols_type="float",
|
847
887
|
)
|
848
|
-
expected_output_cols = self.
|
888
|
+
expected_output_cols, _ = self._align_expected_output(
|
849
889
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
850
890
|
)
|
851
891
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -501,12 +498,23 @@ class CategoricalNB(BaseTransformer):
|
|
501
498
|
autogenerated=self._autogenerated,
|
502
499
|
subproject=_SUBPROJECT,
|
503
500
|
)
|
504
|
-
|
505
|
-
|
506
|
-
expected_output_cols_list=(
|
507
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
508
|
-
),
|
501
|
+
expected_output_cols = (
|
502
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
509
503
|
)
|
504
|
+
if isinstance(dataset, DataFrame):
|
505
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
506
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
507
|
+
)
|
508
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
509
|
+
drop_input_cols=self._drop_input_cols,
|
510
|
+
expected_output_cols_list=expected_output_cols,
|
511
|
+
example_output_pd_df=example_output_pd_df,
|
512
|
+
)
|
513
|
+
else:
|
514
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
515
|
+
drop_input_cols=self._drop_input_cols,
|
516
|
+
expected_output_cols_list=expected_output_cols,
|
517
|
+
)
|
510
518
|
self._sklearn_object = fitted_estimator
|
511
519
|
self._is_fitted = True
|
512
520
|
return output_result
|
@@ -529,6 +537,7 @@ class CategoricalNB(BaseTransformer):
|
|
529
537
|
"""
|
530
538
|
self._infer_input_output_cols(dataset)
|
531
539
|
super()._check_dataset_type(dataset)
|
540
|
+
|
532
541
|
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
533
542
|
estimator=self._sklearn_object,
|
534
543
|
dataset=dataset,
|
@@ -585,12 +594,41 @@ class CategoricalNB(BaseTransformer):
|
|
585
594
|
|
586
595
|
return rv
|
587
596
|
|
588
|
-
def
|
589
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
590
|
-
) -> List[str]:
|
597
|
+
def _align_expected_output(
|
598
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
599
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
600
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
601
|
+
and output dataframe with 1 line.
|
602
|
+
If the method is fit_predict, run 2 lines of data.
|
603
|
+
"""
|
591
604
|
# in case the inferred output column names dimension is different
|
592
605
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
593
|
-
|
606
|
+
|
607
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
608
|
+
# so change the minimum of number of rows to 2
|
609
|
+
num_examples = 2
|
610
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
611
|
+
project=_PROJECT,
|
612
|
+
subproject=_SUBPROJECT,
|
613
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
614
|
+
inspect.currentframe(), CategoricalNB.__class__.__name__
|
615
|
+
),
|
616
|
+
api_calls=[Session.call],
|
617
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
618
|
+
)
|
619
|
+
if output_cols_prefix == "fit_predict_":
|
620
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
621
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
622
|
+
num_examples = self._sklearn_object.n_clusters
|
623
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
624
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
625
|
+
num_examples = self._sklearn_object.min_samples
|
626
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
627
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
628
|
+
num_examples = self._sklearn_object.n_neighbors
|
629
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
630
|
+
else:
|
631
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
594
632
|
|
595
633
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
596
634
|
# seen during the fit.
|
@@ -602,12 +640,14 @@ class CategoricalNB(BaseTransformer):
|
|
602
640
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
603
641
|
if self.sample_weight_col:
|
604
642
|
output_df_columns_set -= set(self.sample_weight_col)
|
643
|
+
|
605
644
|
# if the dimension of inferred output column names is correct; use it
|
606
645
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
607
|
-
return expected_output_cols_list
|
646
|
+
return expected_output_cols_list, output_df_pd
|
608
647
|
# otherwise, use the sklearn estimator's output
|
609
648
|
else:
|
610
|
-
|
649
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
650
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
611
651
|
|
612
652
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
613
653
|
@telemetry.send_api_usage_telemetry(
|
@@ -655,7 +695,7 @@ class CategoricalNB(BaseTransformer):
|
|
655
695
|
drop_input_cols=self._drop_input_cols,
|
656
696
|
expected_output_cols_type="float",
|
657
697
|
)
|
658
|
-
expected_output_cols = self.
|
698
|
+
expected_output_cols, _ = self._align_expected_output(
|
659
699
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
660
700
|
)
|
661
701
|
|
@@ -723,7 +763,7 @@ class CategoricalNB(BaseTransformer):
|
|
723
763
|
drop_input_cols=self._drop_input_cols,
|
724
764
|
expected_output_cols_type="float",
|
725
765
|
)
|
726
|
-
expected_output_cols = self.
|
766
|
+
expected_output_cols, _ = self._align_expected_output(
|
727
767
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
728
768
|
)
|
729
769
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -786,7 +826,7 @@ class CategoricalNB(BaseTransformer):
|
|
786
826
|
drop_input_cols=self._drop_input_cols,
|
787
827
|
expected_output_cols_type="float",
|
788
828
|
)
|
789
|
-
expected_output_cols = self.
|
829
|
+
expected_output_cols, _ = self._align_expected_output(
|
790
830
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
791
831
|
)
|
792
832
|
|
@@ -851,7 +891,7 @@ class CategoricalNB(BaseTransformer):
|
|
851
891
|
drop_input_cols = self._drop_input_cols,
|
852
892
|
expected_output_cols_type="float",
|
853
893
|
)
|
854
|
-
expected_output_cols = self.
|
894
|
+
expected_output_cols, _ = self._align_expected_output(
|
855
895
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
856
896
|
)
|
857
897
|
|