snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +7 -33
- snowflake/ml/_internal/env_utils.py +11 -5
- snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
- snowflake/ml/_internal/telemetry.py +156 -20
- snowflake/ml/_internal/utils/identifier.py +48 -11
- snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
- snowflake/ml/_internal/utils/snowflake_env.py +23 -13
- snowflake/ml/_internal/utils/sql_identifier.py +1 -1
- snowflake/ml/_internal/utils/table_manager.py +19 -1
- snowflake/ml/_internal/utils/uri.py +2 -2
- snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
- snowflake/ml/data/data_connector.py +88 -9
- snowflake/ml/data/data_ingestor.py +18 -1
- snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
- snowflake/ml/data/torch_utils.py +68 -0
- snowflake/ml/dataset/dataset.py +1 -3
- snowflake/ml/dataset/dataset_metadata.py +3 -1
- snowflake/ml/dataset/dataset_reader.py +9 -3
- snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
- snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
- snowflake/ml/feature_store/examples/example_helper.py +69 -31
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
- snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
- snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
- snowflake/ml/feature_store/feature_store.py +100 -41
- snowflake/ml/feature_store/feature_view.py +149 -5
- snowflake/ml/fileset/embedded_stage_fs.py +1 -1
- snowflake/ml/fileset/fileset.py +1 -1
- snowflake/ml/fileset/sfcfs.py +9 -3
- snowflake/ml/model/_client/model/model_impl.py +11 -2
- snowflake/ml/model/_client/model/model_version_impl.py +186 -20
- snowflake/ml/model/_client/ops/model_ops.py +144 -30
- snowflake/ml/model/_client/ops/service_ops.py +312 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
- snowflake/ml/model/_client/sql/model_version.py +13 -4
- snowflake/ml/model/_client/sql/service.py +196 -0
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
- snowflake/ml/model/_model_composer/model_composer.py +5 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
- snowflake/ml/model/_packager/model_env/model_env.py +7 -2
- snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
- snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
- snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
- snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
- snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
- snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
- snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
- snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
- snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
- snowflake/ml/model/_packager/model_packager.py +4 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/_signatures/utils.py +9 -0
- snowflake/ml/model/models/llm.py +3 -1
- snowflake/ml/model/type_hints.py +10 -4
- snowflake/ml/modeling/_internal/constants.py +1 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
- snowflake/ml/modeling/_internal/model_specifications.py +2 -0
- snowflake/ml/modeling/_internal/model_trainer.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
- snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
- snowflake/ml/modeling/cluster/birch.py +60 -21
- snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
- snowflake/ml/modeling/cluster/dbscan.py +60 -21
- snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
- snowflake/ml/modeling/cluster/k_means.py +60 -21
- snowflake/ml/modeling/cluster/mean_shift.py +60 -21
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
- snowflake/ml/modeling/cluster/optics.py +60 -21
- snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
- snowflake/ml/modeling/compose/column_transformer.py +60 -21
- snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
- snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
- snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
- snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
- snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
- snowflake/ml/modeling/covariance/oas.py +60 -21
- snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
- snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
- snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
- snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
- snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/pca.py +60 -21
- snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
- snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
- snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
- snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
- snowflake/ml/modeling/framework/base.py +28 -19
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
- snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
- snowflake/ml/modeling/impute/knn_imputer.py +60 -21
- snowflake/ml/modeling/impute/missing_indicator.py +60 -21
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/lars.py +60 -21
- snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
- snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/perceptron.py +60 -21
- snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ridge.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
- snowflake/ml/modeling/manifold/isomap.py +60 -21
- snowflake/ml/modeling/manifold/mds.py +60 -21
- snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
- snowflake/ml/modeling/manifold/tsne.py +60 -21
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
- snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
- snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +4 -12
- snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
- snowflake/ml/modeling/svm/linear_svc.py +60 -21
- snowflake/ml/modeling/svm/linear_svr.py +60 -21
- snowflake/ml/modeling/svm/nu_svc.py +60 -21
- snowflake/ml/modeling/svm/nu_svr.py +60 -21
- snowflake/ml/modeling/svm/svc.py +60 -21
- snowflake/ml/modeling/svm/svr.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
- snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
- snowflake/ml/registry/_manager/model_manager.py +20 -2
- snowflake/ml/registry/model_registry.py +1 -1
- snowflake/ml/registry/registry.py +1 -2
- snowflake/ml/utils/sql_client.py +22 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -495,12 +492,23 @@ class ComplementNB(BaseTransformer):
|
|
495
492
|
autogenerated=self._autogenerated,
|
496
493
|
subproject=_SUBPROJECT,
|
497
494
|
)
|
498
|
-
|
499
|
-
|
500
|
-
expected_output_cols_list=(
|
501
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
502
|
-
),
|
495
|
+
expected_output_cols = (
|
496
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
503
497
|
)
|
498
|
+
if isinstance(dataset, DataFrame):
|
499
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
500
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
501
|
+
)
|
502
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
503
|
+
drop_input_cols=self._drop_input_cols,
|
504
|
+
expected_output_cols_list=expected_output_cols,
|
505
|
+
example_output_pd_df=example_output_pd_df,
|
506
|
+
)
|
507
|
+
else:
|
508
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
509
|
+
drop_input_cols=self._drop_input_cols,
|
510
|
+
expected_output_cols_list=expected_output_cols,
|
511
|
+
)
|
504
512
|
self._sklearn_object = fitted_estimator
|
505
513
|
self._is_fitted = True
|
506
514
|
return output_result
|
@@ -579,12 +587,41 @@ class ComplementNB(BaseTransformer):
|
|
579
587
|
|
580
588
|
return rv
|
581
589
|
|
582
|
-
def
|
583
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
584
|
-
) -> List[str]:
|
590
|
+
def _align_expected_output(
|
591
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
592
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
593
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
594
|
+
and output dataframe with 1 line.
|
595
|
+
If the method is fit_predict, run 2 lines of data.
|
596
|
+
"""
|
585
597
|
# in case the inferred output column names dimension is different
|
586
598
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
587
|
-
|
599
|
+
|
600
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
601
|
+
# so change the minimum of number of rows to 2
|
602
|
+
num_examples = 2
|
603
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
604
|
+
project=_PROJECT,
|
605
|
+
subproject=_SUBPROJECT,
|
606
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
607
|
+
inspect.currentframe(), ComplementNB.__class__.__name__
|
608
|
+
),
|
609
|
+
api_calls=[Session.call],
|
610
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
611
|
+
)
|
612
|
+
if output_cols_prefix == "fit_predict_":
|
613
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
614
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
615
|
+
num_examples = self._sklearn_object.n_clusters
|
616
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
617
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
618
|
+
num_examples = self._sklearn_object.min_samples
|
619
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
620
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
621
|
+
num_examples = self._sklearn_object.n_neighbors
|
622
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
623
|
+
else:
|
624
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
588
625
|
|
589
626
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
590
627
|
# seen during the fit.
|
@@ -596,12 +633,14 @@ class ComplementNB(BaseTransformer):
|
|
596
633
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
597
634
|
if self.sample_weight_col:
|
598
635
|
output_df_columns_set -= set(self.sample_weight_col)
|
636
|
+
|
599
637
|
# if the dimension of inferred output column names is correct; use it
|
600
638
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
601
|
-
return expected_output_cols_list
|
639
|
+
return expected_output_cols_list, output_df_pd
|
602
640
|
# otherwise, use the sklearn estimator's output
|
603
641
|
else:
|
604
|
-
|
642
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
643
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
605
644
|
|
606
645
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
607
646
|
@telemetry.send_api_usage_telemetry(
|
@@ -649,7 +688,7 @@ class ComplementNB(BaseTransformer):
|
|
649
688
|
drop_input_cols=self._drop_input_cols,
|
650
689
|
expected_output_cols_type="float",
|
651
690
|
)
|
652
|
-
expected_output_cols = self.
|
691
|
+
expected_output_cols, _ = self._align_expected_output(
|
653
692
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
654
693
|
)
|
655
694
|
|
@@ -717,7 +756,7 @@ class ComplementNB(BaseTransformer):
|
|
717
756
|
drop_input_cols=self._drop_input_cols,
|
718
757
|
expected_output_cols_type="float",
|
719
758
|
)
|
720
|
-
expected_output_cols = self.
|
759
|
+
expected_output_cols, _ = self._align_expected_output(
|
721
760
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
722
761
|
)
|
723
762
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -780,7 +819,7 @@ class ComplementNB(BaseTransformer):
|
|
780
819
|
drop_input_cols=self._drop_input_cols,
|
781
820
|
expected_output_cols_type="float",
|
782
821
|
)
|
783
|
-
expected_output_cols = self.
|
822
|
+
expected_output_cols, _ = self._align_expected_output(
|
784
823
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
785
824
|
)
|
786
825
|
|
@@ -845,7 +884,7 @@ class ComplementNB(BaseTransformer):
|
|
845
884
|
drop_input_cols = self._drop_input_cols,
|
846
885
|
expected_output_cols_type="float",
|
847
886
|
)
|
848
|
-
expected_output_cols = self.
|
887
|
+
expected_output_cols, _ = self._align_expected_output(
|
849
888
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
850
889
|
)
|
851
890
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -476,12 +473,23 @@ class GaussianNB(BaseTransformer):
|
|
476
473
|
autogenerated=self._autogenerated,
|
477
474
|
subproject=_SUBPROJECT,
|
478
475
|
)
|
479
|
-
|
480
|
-
|
481
|
-
expected_output_cols_list=(
|
482
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
483
|
-
),
|
476
|
+
expected_output_cols = (
|
477
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
484
478
|
)
|
479
|
+
if isinstance(dataset, DataFrame):
|
480
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
481
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
482
|
+
)
|
483
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
484
|
+
drop_input_cols=self._drop_input_cols,
|
485
|
+
expected_output_cols_list=expected_output_cols,
|
486
|
+
example_output_pd_df=example_output_pd_df,
|
487
|
+
)
|
488
|
+
else:
|
489
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
490
|
+
drop_input_cols=self._drop_input_cols,
|
491
|
+
expected_output_cols_list=expected_output_cols,
|
492
|
+
)
|
485
493
|
self._sklearn_object = fitted_estimator
|
486
494
|
self._is_fitted = True
|
487
495
|
return output_result
|
@@ -560,12 +568,41 @@ class GaussianNB(BaseTransformer):
|
|
560
568
|
|
561
569
|
return rv
|
562
570
|
|
563
|
-
def
|
564
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
565
|
-
) -> List[str]:
|
571
|
+
def _align_expected_output(
|
572
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
573
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
574
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
575
|
+
and output dataframe with 1 line.
|
576
|
+
If the method is fit_predict, run 2 lines of data.
|
577
|
+
"""
|
566
578
|
# in case the inferred output column names dimension is different
|
567
579
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
568
|
-
|
580
|
+
|
581
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
582
|
+
# so change the minimum of number of rows to 2
|
583
|
+
num_examples = 2
|
584
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
585
|
+
project=_PROJECT,
|
586
|
+
subproject=_SUBPROJECT,
|
587
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
588
|
+
inspect.currentframe(), GaussianNB.__class__.__name__
|
589
|
+
),
|
590
|
+
api_calls=[Session.call],
|
591
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
592
|
+
)
|
593
|
+
if output_cols_prefix == "fit_predict_":
|
594
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
595
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
596
|
+
num_examples = self._sklearn_object.n_clusters
|
597
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
598
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
599
|
+
num_examples = self._sklearn_object.min_samples
|
600
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
601
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
602
|
+
num_examples = self._sklearn_object.n_neighbors
|
603
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
604
|
+
else:
|
605
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
569
606
|
|
570
607
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
571
608
|
# seen during the fit.
|
@@ -577,12 +614,14 @@ class GaussianNB(BaseTransformer):
|
|
577
614
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
578
615
|
if self.sample_weight_col:
|
579
616
|
output_df_columns_set -= set(self.sample_weight_col)
|
617
|
+
|
580
618
|
# if the dimension of inferred output column names is correct; use it
|
581
619
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
582
|
-
return expected_output_cols_list
|
620
|
+
return expected_output_cols_list, output_df_pd
|
583
621
|
# otherwise, use the sklearn estimator's output
|
584
622
|
else:
|
585
|
-
|
623
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
624
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
586
625
|
|
587
626
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
588
627
|
@telemetry.send_api_usage_telemetry(
|
@@ -630,7 +669,7 @@ class GaussianNB(BaseTransformer):
|
|
630
669
|
drop_input_cols=self._drop_input_cols,
|
631
670
|
expected_output_cols_type="float",
|
632
671
|
)
|
633
|
-
expected_output_cols = self.
|
672
|
+
expected_output_cols, _ = self._align_expected_output(
|
634
673
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
635
674
|
)
|
636
675
|
|
@@ -698,7 +737,7 @@ class GaussianNB(BaseTransformer):
|
|
698
737
|
drop_input_cols=self._drop_input_cols,
|
699
738
|
expected_output_cols_type="float",
|
700
739
|
)
|
701
|
-
expected_output_cols = self.
|
740
|
+
expected_output_cols, _ = self._align_expected_output(
|
702
741
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
703
742
|
)
|
704
743
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -761,7 +800,7 @@ class GaussianNB(BaseTransformer):
|
|
761
800
|
drop_input_cols=self._drop_input_cols,
|
762
801
|
expected_output_cols_type="float",
|
763
802
|
)
|
764
|
-
expected_output_cols = self.
|
803
|
+
expected_output_cols, _ = self._align_expected_output(
|
765
804
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
766
805
|
)
|
767
806
|
|
@@ -826,7 +865,7 @@ class GaussianNB(BaseTransformer):
|
|
826
865
|
drop_input_cols = self._drop_input_cols,
|
827
866
|
expected_output_cols_type="float",
|
828
867
|
)
|
829
|
-
expected_output_cols = self.
|
868
|
+
expected_output_cols, _ = self._align_expected_output(
|
830
869
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
831
870
|
)
|
832
871
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -489,12 +486,23 @@ class MultinomialNB(BaseTransformer):
|
|
489
486
|
autogenerated=self._autogenerated,
|
490
487
|
subproject=_SUBPROJECT,
|
491
488
|
)
|
492
|
-
|
493
|
-
|
494
|
-
expected_output_cols_list=(
|
495
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
496
|
-
),
|
489
|
+
expected_output_cols = (
|
490
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
497
491
|
)
|
492
|
+
if isinstance(dataset, DataFrame):
|
493
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
494
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
495
|
+
)
|
496
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
497
|
+
drop_input_cols=self._drop_input_cols,
|
498
|
+
expected_output_cols_list=expected_output_cols,
|
499
|
+
example_output_pd_df=example_output_pd_df,
|
500
|
+
)
|
501
|
+
else:
|
502
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
503
|
+
drop_input_cols=self._drop_input_cols,
|
504
|
+
expected_output_cols_list=expected_output_cols,
|
505
|
+
)
|
498
506
|
self._sklearn_object = fitted_estimator
|
499
507
|
self._is_fitted = True
|
500
508
|
return output_result
|
@@ -573,12 +581,41 @@ class MultinomialNB(BaseTransformer):
|
|
573
581
|
|
574
582
|
return rv
|
575
583
|
|
576
|
-
def
|
577
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
578
|
-
) -> List[str]:
|
584
|
+
def _align_expected_output(
|
585
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
586
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
587
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
588
|
+
and output dataframe with 1 line.
|
589
|
+
If the method is fit_predict, run 2 lines of data.
|
590
|
+
"""
|
579
591
|
# in case the inferred output column names dimension is different
|
580
592
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
581
|
-
|
593
|
+
|
594
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
595
|
+
# so change the minimum of number of rows to 2
|
596
|
+
num_examples = 2
|
597
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
598
|
+
project=_PROJECT,
|
599
|
+
subproject=_SUBPROJECT,
|
600
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
601
|
+
inspect.currentframe(), MultinomialNB.__class__.__name__
|
602
|
+
),
|
603
|
+
api_calls=[Session.call],
|
604
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
605
|
+
)
|
606
|
+
if output_cols_prefix == "fit_predict_":
|
607
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
608
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
609
|
+
num_examples = self._sklearn_object.n_clusters
|
610
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
611
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
612
|
+
num_examples = self._sklearn_object.min_samples
|
613
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
614
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
615
|
+
num_examples = self._sklearn_object.n_neighbors
|
616
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
617
|
+
else:
|
618
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
582
619
|
|
583
620
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
584
621
|
# seen during the fit.
|
@@ -590,12 +627,14 @@ class MultinomialNB(BaseTransformer):
|
|
590
627
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
591
628
|
if self.sample_weight_col:
|
592
629
|
output_df_columns_set -= set(self.sample_weight_col)
|
630
|
+
|
593
631
|
# if the dimension of inferred output column names is correct; use it
|
594
632
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
595
|
-
return expected_output_cols_list
|
633
|
+
return expected_output_cols_list, output_df_pd
|
596
634
|
# otherwise, use the sklearn estimator's output
|
597
635
|
else:
|
598
|
-
|
636
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
637
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
599
638
|
|
600
639
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
601
640
|
@telemetry.send_api_usage_telemetry(
|
@@ -643,7 +682,7 @@ class MultinomialNB(BaseTransformer):
|
|
643
682
|
drop_input_cols=self._drop_input_cols,
|
644
683
|
expected_output_cols_type="float",
|
645
684
|
)
|
646
|
-
expected_output_cols = self.
|
685
|
+
expected_output_cols, _ = self._align_expected_output(
|
647
686
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
648
687
|
)
|
649
688
|
|
@@ -711,7 +750,7 @@ class MultinomialNB(BaseTransformer):
|
|
711
750
|
drop_input_cols=self._drop_input_cols,
|
712
751
|
expected_output_cols_type="float",
|
713
752
|
)
|
714
|
-
expected_output_cols = self.
|
753
|
+
expected_output_cols, _ = self._align_expected_output(
|
715
754
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
716
755
|
)
|
717
756
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -774,7 +813,7 @@ class MultinomialNB(BaseTransformer):
|
|
774
813
|
drop_input_cols=self._drop_input_cols,
|
775
814
|
expected_output_cols_type="float",
|
776
815
|
)
|
777
|
-
expected_output_cols = self.
|
816
|
+
expected_output_cols, _ = self._align_expected_output(
|
778
817
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
779
818
|
)
|
780
819
|
|
@@ -839,7 +878,7 @@ class MultinomialNB(BaseTransformer):
|
|
839
878
|
drop_input_cols = self._drop_input_cols,
|
840
879
|
expected_output_cols_type="float",
|
841
880
|
)
|
842
|
-
expected_output_cols = self.
|
881
|
+
expected_output_cols, _ = self._align_expected_output(
|
843
882
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
844
883
|
)
|
845
884
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -546,12 +543,23 @@ class KNeighborsClassifier(BaseTransformer):
|
|
546
543
|
autogenerated=self._autogenerated,
|
547
544
|
subproject=_SUBPROJECT,
|
548
545
|
)
|
549
|
-
|
550
|
-
|
551
|
-
expected_output_cols_list=(
|
552
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
553
|
-
),
|
546
|
+
expected_output_cols = (
|
547
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
554
548
|
)
|
549
|
+
if isinstance(dataset, DataFrame):
|
550
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
551
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
552
|
+
)
|
553
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
554
|
+
drop_input_cols=self._drop_input_cols,
|
555
|
+
expected_output_cols_list=expected_output_cols,
|
556
|
+
example_output_pd_df=example_output_pd_df,
|
557
|
+
)
|
558
|
+
else:
|
559
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
560
|
+
drop_input_cols=self._drop_input_cols,
|
561
|
+
expected_output_cols_list=expected_output_cols,
|
562
|
+
)
|
555
563
|
self._sklearn_object = fitted_estimator
|
556
564
|
self._is_fitted = True
|
557
565
|
return output_result
|
@@ -630,12 +638,41 @@ class KNeighborsClassifier(BaseTransformer):
|
|
630
638
|
|
631
639
|
return rv
|
632
640
|
|
633
|
-
def
|
634
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
635
|
-
) -> List[str]:
|
641
|
+
def _align_expected_output(
|
642
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
643
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
644
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
645
|
+
and output dataframe with 1 line.
|
646
|
+
If the method is fit_predict, run 2 lines of data.
|
647
|
+
"""
|
636
648
|
# in case the inferred output column names dimension is different
|
637
649
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
638
|
-
|
650
|
+
|
651
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
652
|
+
# so change the minimum of number of rows to 2
|
653
|
+
num_examples = 2
|
654
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
655
|
+
project=_PROJECT,
|
656
|
+
subproject=_SUBPROJECT,
|
657
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
658
|
+
inspect.currentframe(), KNeighborsClassifier.__class__.__name__
|
659
|
+
),
|
660
|
+
api_calls=[Session.call],
|
661
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
662
|
+
)
|
663
|
+
if output_cols_prefix == "fit_predict_":
|
664
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
665
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
666
|
+
num_examples = self._sklearn_object.n_clusters
|
667
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
668
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
669
|
+
num_examples = self._sklearn_object.min_samples
|
670
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
671
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
672
|
+
num_examples = self._sklearn_object.n_neighbors
|
673
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
674
|
+
else:
|
675
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
639
676
|
|
640
677
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
641
678
|
# seen during the fit.
|
@@ -647,12 +684,14 @@ class KNeighborsClassifier(BaseTransformer):
|
|
647
684
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
648
685
|
if self.sample_weight_col:
|
649
686
|
output_df_columns_set -= set(self.sample_weight_col)
|
687
|
+
|
650
688
|
# if the dimension of inferred output column names is correct; use it
|
651
689
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
652
|
-
return expected_output_cols_list
|
690
|
+
return expected_output_cols_list, output_df_pd
|
653
691
|
# otherwise, use the sklearn estimator's output
|
654
692
|
else:
|
655
|
-
|
693
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
694
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
656
695
|
|
657
696
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
658
697
|
@telemetry.send_api_usage_telemetry(
|
@@ -700,7 +739,7 @@ class KNeighborsClassifier(BaseTransformer):
|
|
700
739
|
drop_input_cols=self._drop_input_cols,
|
701
740
|
expected_output_cols_type="float",
|
702
741
|
)
|
703
|
-
expected_output_cols = self.
|
742
|
+
expected_output_cols, _ = self._align_expected_output(
|
704
743
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
705
744
|
)
|
706
745
|
|
@@ -768,7 +807,7 @@ class KNeighborsClassifier(BaseTransformer):
|
|
768
807
|
drop_input_cols=self._drop_input_cols,
|
769
808
|
expected_output_cols_type="float",
|
770
809
|
)
|
771
|
-
expected_output_cols = self.
|
810
|
+
expected_output_cols, _ = self._align_expected_output(
|
772
811
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
773
812
|
)
|
774
813
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -831,7 +870,7 @@ class KNeighborsClassifier(BaseTransformer):
|
|
831
870
|
drop_input_cols=self._drop_input_cols,
|
832
871
|
expected_output_cols_type="float",
|
833
872
|
)
|
834
|
-
expected_output_cols = self.
|
873
|
+
expected_output_cols, _ = self._align_expected_output(
|
835
874
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
836
875
|
)
|
837
876
|
|
@@ -896,7 +935,7 @@ class KNeighborsClassifier(BaseTransformer):
|
|
896
935
|
drop_input_cols = self._drop_input_cols,
|
897
936
|
expected_output_cols_type="float",
|
898
937
|
)
|
899
|
-
expected_output_cols = self.
|
938
|
+
expected_output_cols, _ = self._align_expected_output(
|
900
939
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
901
940
|
)
|
902
941
|
|