snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +7 -33
- snowflake/ml/_internal/env_utils.py +11 -5
- snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
- snowflake/ml/_internal/telemetry.py +156 -20
- snowflake/ml/_internal/utils/identifier.py +48 -11
- snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
- snowflake/ml/_internal/utils/snowflake_env.py +23 -13
- snowflake/ml/_internal/utils/sql_identifier.py +1 -1
- snowflake/ml/_internal/utils/table_manager.py +19 -1
- snowflake/ml/_internal/utils/uri.py +2 -2
- snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
- snowflake/ml/data/data_connector.py +88 -9
- snowflake/ml/data/data_ingestor.py +18 -1
- snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
- snowflake/ml/data/torch_utils.py +68 -0
- snowflake/ml/dataset/dataset.py +1 -3
- snowflake/ml/dataset/dataset_metadata.py +3 -1
- snowflake/ml/dataset/dataset_reader.py +9 -3
- snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
- snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
- snowflake/ml/feature_store/examples/example_helper.py +69 -31
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
- snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
- snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
- snowflake/ml/feature_store/feature_store.py +100 -41
- snowflake/ml/feature_store/feature_view.py +149 -5
- snowflake/ml/fileset/embedded_stage_fs.py +1 -1
- snowflake/ml/fileset/fileset.py +1 -1
- snowflake/ml/fileset/sfcfs.py +9 -3
- snowflake/ml/model/_client/model/model_impl.py +11 -2
- snowflake/ml/model/_client/model/model_version_impl.py +186 -20
- snowflake/ml/model/_client/ops/model_ops.py +144 -30
- snowflake/ml/model/_client/ops/service_ops.py +312 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
- snowflake/ml/model/_client/sql/model_version.py +13 -4
- snowflake/ml/model/_client/sql/service.py +196 -0
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
- snowflake/ml/model/_model_composer/model_composer.py +5 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
- snowflake/ml/model/_packager/model_env/model_env.py +7 -2
- snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
- snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
- snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
- snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
- snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
- snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
- snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
- snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
- snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
- snowflake/ml/model/_packager/model_packager.py +4 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/_signatures/utils.py +9 -0
- snowflake/ml/model/models/llm.py +3 -1
- snowflake/ml/model/type_hints.py +10 -4
- snowflake/ml/modeling/_internal/constants.py +1 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
- snowflake/ml/modeling/_internal/model_specifications.py +2 -0
- snowflake/ml/modeling/_internal/model_trainer.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
- snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
- snowflake/ml/modeling/cluster/birch.py +60 -21
- snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
- snowflake/ml/modeling/cluster/dbscan.py +60 -21
- snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
- snowflake/ml/modeling/cluster/k_means.py +60 -21
- snowflake/ml/modeling/cluster/mean_shift.py +60 -21
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
- snowflake/ml/modeling/cluster/optics.py +60 -21
- snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
- snowflake/ml/modeling/compose/column_transformer.py +60 -21
- snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
- snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
- snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
- snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
- snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
- snowflake/ml/modeling/covariance/oas.py +60 -21
- snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
- snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
- snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
- snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
- snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/pca.py +60 -21
- snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
- snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
- snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
- snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
- snowflake/ml/modeling/framework/base.py +28 -19
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
- snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
- snowflake/ml/modeling/impute/knn_imputer.py +60 -21
- snowflake/ml/modeling/impute/missing_indicator.py +60 -21
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/lars.py +60 -21
- snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
- snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/perceptron.py +60 -21
- snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ridge.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
- snowflake/ml/modeling/manifold/isomap.py +60 -21
- snowflake/ml/modeling/manifold/mds.py +60 -21
- snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
- snowflake/ml/modeling/manifold/tsne.py +60 -21
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
- snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
- snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +4 -12
- snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
- snowflake/ml/modeling/svm/linear_svc.py +60 -21
- snowflake/ml/modeling/svm/linear_svr.py +60 -21
- snowflake/ml/modeling/svm/nu_svc.py +60 -21
- snowflake/ml/modeling/svm/nu_svr.py +60 -21
- snowflake/ml/modeling/svm/svc.py +60 -21
- snowflake/ml/modeling/svm/svr.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
- snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
- snowflake/ml/registry/_manager/model_manager.py +20 -2
- snowflake/ml/registry/model_registry.py +1 -1
- snowflake/ml/registry/registry.py +1 -2
- snowflake/ml/utils/sql_client.py +22 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -492,12 +489,23 @@ class OneVsRestClassifier(BaseTransformer):
|
|
492
489
|
autogenerated=self._autogenerated,
|
493
490
|
subproject=_SUBPROJECT,
|
494
491
|
)
|
495
|
-
|
496
|
-
|
497
|
-
expected_output_cols_list=(
|
498
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
499
|
-
),
|
492
|
+
expected_output_cols = (
|
493
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
500
494
|
)
|
495
|
+
if isinstance(dataset, DataFrame):
|
496
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
497
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
498
|
+
)
|
499
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
500
|
+
drop_input_cols=self._drop_input_cols,
|
501
|
+
expected_output_cols_list=expected_output_cols,
|
502
|
+
example_output_pd_df=example_output_pd_df,
|
503
|
+
)
|
504
|
+
else:
|
505
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
506
|
+
drop_input_cols=self._drop_input_cols,
|
507
|
+
expected_output_cols_list=expected_output_cols,
|
508
|
+
)
|
501
509
|
self._sklearn_object = fitted_estimator
|
502
510
|
self._is_fitted = True
|
503
511
|
return output_result
|
@@ -576,12 +584,41 @@ class OneVsRestClassifier(BaseTransformer):
|
|
576
584
|
|
577
585
|
return rv
|
578
586
|
|
579
|
-
def
|
580
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
581
|
-
) -> List[str]:
|
587
|
+
def _align_expected_output(
|
588
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
589
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
590
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
591
|
+
and output dataframe with 1 line.
|
592
|
+
If the method is fit_predict, run 2 lines of data.
|
593
|
+
"""
|
582
594
|
# in case the inferred output column names dimension is different
|
583
595
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
584
|
-
|
596
|
+
|
597
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
598
|
+
# so change the minimum of number of rows to 2
|
599
|
+
num_examples = 2
|
600
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
601
|
+
project=_PROJECT,
|
602
|
+
subproject=_SUBPROJECT,
|
603
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
604
|
+
inspect.currentframe(), OneVsRestClassifier.__class__.__name__
|
605
|
+
),
|
606
|
+
api_calls=[Session.call],
|
607
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
608
|
+
)
|
609
|
+
if output_cols_prefix == "fit_predict_":
|
610
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
611
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
612
|
+
num_examples = self._sklearn_object.n_clusters
|
613
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
614
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
615
|
+
num_examples = self._sklearn_object.min_samples
|
616
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
617
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
618
|
+
num_examples = self._sklearn_object.n_neighbors
|
619
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
620
|
+
else:
|
621
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
585
622
|
|
586
623
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
587
624
|
# seen during the fit.
|
@@ -593,12 +630,14 @@ class OneVsRestClassifier(BaseTransformer):
|
|
593
630
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
594
631
|
if self.sample_weight_col:
|
595
632
|
output_df_columns_set -= set(self.sample_weight_col)
|
633
|
+
|
596
634
|
# if the dimension of inferred output column names is correct; use it
|
597
635
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
598
|
-
return expected_output_cols_list
|
636
|
+
return expected_output_cols_list, output_df_pd
|
599
637
|
# otherwise, use the sklearn estimator's output
|
600
638
|
else:
|
601
|
-
|
639
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
640
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
602
641
|
|
603
642
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
604
643
|
@telemetry.send_api_usage_telemetry(
|
@@ -646,7 +685,7 @@ class OneVsRestClassifier(BaseTransformer):
|
|
646
685
|
drop_input_cols=self._drop_input_cols,
|
647
686
|
expected_output_cols_type="float",
|
648
687
|
)
|
649
|
-
expected_output_cols = self.
|
688
|
+
expected_output_cols, _ = self._align_expected_output(
|
650
689
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
651
690
|
)
|
652
691
|
|
@@ -714,7 +753,7 @@ class OneVsRestClassifier(BaseTransformer):
|
|
714
753
|
drop_input_cols=self._drop_input_cols,
|
715
754
|
expected_output_cols_type="float",
|
716
755
|
)
|
717
|
-
expected_output_cols = self.
|
756
|
+
expected_output_cols, _ = self._align_expected_output(
|
718
757
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
719
758
|
)
|
720
759
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -779,7 +818,7 @@ class OneVsRestClassifier(BaseTransformer):
|
|
779
818
|
drop_input_cols=self._drop_input_cols,
|
780
819
|
expected_output_cols_type="float",
|
781
820
|
)
|
782
|
-
expected_output_cols = self.
|
821
|
+
expected_output_cols, _ = self._align_expected_output(
|
783
822
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
784
823
|
)
|
785
824
|
|
@@ -844,7 +883,7 @@ class OneVsRestClassifier(BaseTransformer):
|
|
844
883
|
drop_input_cols = self._drop_input_cols,
|
845
884
|
expected_output_cols_type="float",
|
846
885
|
)
|
847
|
-
expected_output_cols = self.
|
886
|
+
expected_output_cols, _ = self._align_expected_output(
|
848
887
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
849
888
|
)
|
850
889
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -495,12 +492,23 @@ class OutputCodeClassifier(BaseTransformer):
|
|
495
492
|
autogenerated=self._autogenerated,
|
496
493
|
subproject=_SUBPROJECT,
|
497
494
|
)
|
498
|
-
|
499
|
-
|
500
|
-
expected_output_cols_list=(
|
501
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
502
|
-
),
|
495
|
+
expected_output_cols = (
|
496
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
503
497
|
)
|
498
|
+
if isinstance(dataset, DataFrame):
|
499
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
500
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
501
|
+
)
|
502
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
503
|
+
drop_input_cols=self._drop_input_cols,
|
504
|
+
expected_output_cols_list=expected_output_cols,
|
505
|
+
example_output_pd_df=example_output_pd_df,
|
506
|
+
)
|
507
|
+
else:
|
508
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
509
|
+
drop_input_cols=self._drop_input_cols,
|
510
|
+
expected_output_cols_list=expected_output_cols,
|
511
|
+
)
|
504
512
|
self._sklearn_object = fitted_estimator
|
505
513
|
self._is_fitted = True
|
506
514
|
return output_result
|
@@ -579,12 +587,41 @@ class OutputCodeClassifier(BaseTransformer):
|
|
579
587
|
|
580
588
|
return rv
|
581
589
|
|
582
|
-
def
|
583
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
584
|
-
) -> List[str]:
|
590
|
+
def _align_expected_output(
|
591
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
592
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
593
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
594
|
+
and output dataframe with 1 line.
|
595
|
+
If the method is fit_predict, run 2 lines of data.
|
596
|
+
"""
|
585
597
|
# in case the inferred output column names dimension is different
|
586
598
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
587
|
-
|
599
|
+
|
600
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
601
|
+
# so change the minimum of number of rows to 2
|
602
|
+
num_examples = 2
|
603
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
604
|
+
project=_PROJECT,
|
605
|
+
subproject=_SUBPROJECT,
|
606
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
607
|
+
inspect.currentframe(), OutputCodeClassifier.__class__.__name__
|
608
|
+
),
|
609
|
+
api_calls=[Session.call],
|
610
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
611
|
+
)
|
612
|
+
if output_cols_prefix == "fit_predict_":
|
613
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
614
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
615
|
+
num_examples = self._sklearn_object.n_clusters
|
616
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
617
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
618
|
+
num_examples = self._sklearn_object.min_samples
|
619
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
620
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
621
|
+
num_examples = self._sklearn_object.n_neighbors
|
622
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
623
|
+
else:
|
624
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
588
625
|
|
589
626
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
590
627
|
# seen during the fit.
|
@@ -596,12 +633,14 @@ class OutputCodeClassifier(BaseTransformer):
|
|
596
633
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
597
634
|
if self.sample_weight_col:
|
598
635
|
output_df_columns_set -= set(self.sample_weight_col)
|
636
|
+
|
599
637
|
# if the dimension of inferred output column names is correct; use it
|
600
638
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
601
|
-
return expected_output_cols_list
|
639
|
+
return expected_output_cols_list, output_df_pd
|
602
640
|
# otherwise, use the sklearn estimator's output
|
603
641
|
else:
|
604
|
-
|
642
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
643
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
605
644
|
|
606
645
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
607
646
|
@telemetry.send_api_usage_telemetry(
|
@@ -647,7 +686,7 @@ class OutputCodeClassifier(BaseTransformer):
|
|
647
686
|
drop_input_cols=self._drop_input_cols,
|
648
687
|
expected_output_cols_type="float",
|
649
688
|
)
|
650
|
-
expected_output_cols = self.
|
689
|
+
expected_output_cols, _ = self._align_expected_output(
|
651
690
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
652
691
|
)
|
653
692
|
|
@@ -713,7 +752,7 @@ class OutputCodeClassifier(BaseTransformer):
|
|
713
752
|
drop_input_cols=self._drop_input_cols,
|
714
753
|
expected_output_cols_type="float",
|
715
754
|
)
|
716
|
-
expected_output_cols = self.
|
755
|
+
expected_output_cols, _ = self._align_expected_output(
|
717
756
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
718
757
|
)
|
719
758
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -776,7 +815,7 @@ class OutputCodeClassifier(BaseTransformer):
|
|
776
815
|
drop_input_cols=self._drop_input_cols,
|
777
816
|
expected_output_cols_type="float",
|
778
817
|
)
|
779
|
-
expected_output_cols = self.
|
818
|
+
expected_output_cols, _ = self._align_expected_output(
|
780
819
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
781
820
|
)
|
782
821
|
|
@@ -841,7 +880,7 @@ class OutputCodeClassifier(BaseTransformer):
|
|
841
880
|
drop_input_cols = self._drop_input_cols,
|
842
881
|
expected_output_cols_type="float",
|
843
882
|
)
|
844
|
-
expected_output_cols = self.
|
883
|
+
expected_output_cols, _ = self._align_expected_output(
|
845
884
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
846
885
|
)
|
847
886
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -495,12 +492,23 @@ class BernoulliNB(BaseTransformer):
|
|
495
492
|
autogenerated=self._autogenerated,
|
496
493
|
subproject=_SUBPROJECT,
|
497
494
|
)
|
498
|
-
|
499
|
-
|
500
|
-
expected_output_cols_list=(
|
501
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
502
|
-
),
|
495
|
+
expected_output_cols = (
|
496
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
503
497
|
)
|
498
|
+
if isinstance(dataset, DataFrame):
|
499
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
500
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
501
|
+
)
|
502
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
503
|
+
drop_input_cols=self._drop_input_cols,
|
504
|
+
expected_output_cols_list=expected_output_cols,
|
505
|
+
example_output_pd_df=example_output_pd_df,
|
506
|
+
)
|
507
|
+
else:
|
508
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
509
|
+
drop_input_cols=self._drop_input_cols,
|
510
|
+
expected_output_cols_list=expected_output_cols,
|
511
|
+
)
|
504
512
|
self._sklearn_object = fitted_estimator
|
505
513
|
self._is_fitted = True
|
506
514
|
return output_result
|
@@ -579,12 +587,41 @@ class BernoulliNB(BaseTransformer):
|
|
579
587
|
|
580
588
|
return rv
|
581
589
|
|
582
|
-
def
|
583
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
584
|
-
) -> List[str]:
|
590
|
+
def _align_expected_output(
|
591
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
592
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
593
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
594
|
+
and output dataframe with 1 line.
|
595
|
+
If the method is fit_predict, run 2 lines of data.
|
596
|
+
"""
|
585
597
|
# in case the inferred output column names dimension is different
|
586
598
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
587
|
-
|
599
|
+
|
600
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
601
|
+
# so change the minimum of number of rows to 2
|
602
|
+
num_examples = 2
|
603
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
604
|
+
project=_PROJECT,
|
605
|
+
subproject=_SUBPROJECT,
|
606
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
607
|
+
inspect.currentframe(), BernoulliNB.__class__.__name__
|
608
|
+
),
|
609
|
+
api_calls=[Session.call],
|
610
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
611
|
+
)
|
612
|
+
if output_cols_prefix == "fit_predict_":
|
613
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
614
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
615
|
+
num_examples = self._sklearn_object.n_clusters
|
616
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
617
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
618
|
+
num_examples = self._sklearn_object.min_samples
|
619
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
620
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
621
|
+
num_examples = self._sklearn_object.n_neighbors
|
622
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
623
|
+
else:
|
624
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
588
625
|
|
589
626
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
590
627
|
# seen during the fit.
|
@@ -596,12 +633,14 @@ class BernoulliNB(BaseTransformer):
|
|
596
633
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
597
634
|
if self.sample_weight_col:
|
598
635
|
output_df_columns_set -= set(self.sample_weight_col)
|
636
|
+
|
599
637
|
# if the dimension of inferred output column names is correct; use it
|
600
638
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
601
|
-
return expected_output_cols_list
|
639
|
+
return expected_output_cols_list, output_df_pd
|
602
640
|
# otherwise, use the sklearn estimator's output
|
603
641
|
else:
|
604
|
-
|
642
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
643
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
605
644
|
|
606
645
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
607
646
|
@telemetry.send_api_usage_telemetry(
|
@@ -649,7 +688,7 @@ class BernoulliNB(BaseTransformer):
|
|
649
688
|
drop_input_cols=self._drop_input_cols,
|
650
689
|
expected_output_cols_type="float",
|
651
690
|
)
|
652
|
-
expected_output_cols = self.
|
691
|
+
expected_output_cols, _ = self._align_expected_output(
|
653
692
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
654
693
|
)
|
655
694
|
|
@@ -717,7 +756,7 @@ class BernoulliNB(BaseTransformer):
|
|
717
756
|
drop_input_cols=self._drop_input_cols,
|
718
757
|
expected_output_cols_type="float",
|
719
758
|
)
|
720
|
-
expected_output_cols = self.
|
759
|
+
expected_output_cols, _ = self._align_expected_output(
|
721
760
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
722
761
|
)
|
723
762
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -780,7 +819,7 @@ class BernoulliNB(BaseTransformer):
|
|
780
819
|
drop_input_cols=self._drop_input_cols,
|
781
820
|
expected_output_cols_type="float",
|
782
821
|
)
|
783
|
-
expected_output_cols = self.
|
822
|
+
expected_output_cols, _ = self._align_expected_output(
|
784
823
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
785
824
|
)
|
786
825
|
|
@@ -845,7 +884,7 @@ class BernoulliNB(BaseTransformer):
|
|
845
884
|
drop_input_cols = self._drop_input_cols,
|
846
885
|
expected_output_cols_type="float",
|
847
886
|
)
|
848
|
-
expected_output_cols = self.
|
887
|
+
expected_output_cols, _ = self._align_expected_output(
|
849
888
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
850
889
|
)
|
851
890
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -501,12 +498,23 @@ class CategoricalNB(BaseTransformer):
|
|
501
498
|
autogenerated=self._autogenerated,
|
502
499
|
subproject=_SUBPROJECT,
|
503
500
|
)
|
504
|
-
|
505
|
-
|
506
|
-
expected_output_cols_list=(
|
507
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
508
|
-
),
|
501
|
+
expected_output_cols = (
|
502
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
509
503
|
)
|
504
|
+
if isinstance(dataset, DataFrame):
|
505
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
506
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
507
|
+
)
|
508
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
509
|
+
drop_input_cols=self._drop_input_cols,
|
510
|
+
expected_output_cols_list=expected_output_cols,
|
511
|
+
example_output_pd_df=example_output_pd_df,
|
512
|
+
)
|
513
|
+
else:
|
514
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
515
|
+
drop_input_cols=self._drop_input_cols,
|
516
|
+
expected_output_cols_list=expected_output_cols,
|
517
|
+
)
|
510
518
|
self._sklearn_object = fitted_estimator
|
511
519
|
self._is_fitted = True
|
512
520
|
return output_result
|
@@ -585,12 +593,41 @@ class CategoricalNB(BaseTransformer):
|
|
585
593
|
|
586
594
|
return rv
|
587
595
|
|
588
|
-
def
|
589
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
590
|
-
) -> List[str]:
|
596
|
+
def _align_expected_output(
|
597
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
598
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
599
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
600
|
+
and output dataframe with 1 line.
|
601
|
+
If the method is fit_predict, run 2 lines of data.
|
602
|
+
"""
|
591
603
|
# in case the inferred output column names dimension is different
|
592
604
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
593
|
-
|
605
|
+
|
606
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
607
|
+
# so change the minimum of number of rows to 2
|
608
|
+
num_examples = 2
|
609
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
610
|
+
project=_PROJECT,
|
611
|
+
subproject=_SUBPROJECT,
|
612
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
613
|
+
inspect.currentframe(), CategoricalNB.__class__.__name__
|
614
|
+
),
|
615
|
+
api_calls=[Session.call],
|
616
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
617
|
+
)
|
618
|
+
if output_cols_prefix == "fit_predict_":
|
619
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
620
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
621
|
+
num_examples = self._sklearn_object.n_clusters
|
622
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
623
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
624
|
+
num_examples = self._sklearn_object.min_samples
|
625
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
626
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
627
|
+
num_examples = self._sklearn_object.n_neighbors
|
628
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
629
|
+
else:
|
630
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
594
631
|
|
595
632
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
596
633
|
# seen during the fit.
|
@@ -602,12 +639,14 @@ class CategoricalNB(BaseTransformer):
|
|
602
639
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
603
640
|
if self.sample_weight_col:
|
604
641
|
output_df_columns_set -= set(self.sample_weight_col)
|
642
|
+
|
605
643
|
# if the dimension of inferred output column names is correct; use it
|
606
644
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
607
|
-
return expected_output_cols_list
|
645
|
+
return expected_output_cols_list, output_df_pd
|
608
646
|
# otherwise, use the sklearn estimator's output
|
609
647
|
else:
|
610
|
-
|
648
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
649
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
611
650
|
|
612
651
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
613
652
|
@telemetry.send_api_usage_telemetry(
|
@@ -655,7 +694,7 @@ class CategoricalNB(BaseTransformer):
|
|
655
694
|
drop_input_cols=self._drop_input_cols,
|
656
695
|
expected_output_cols_type="float",
|
657
696
|
)
|
658
|
-
expected_output_cols = self.
|
697
|
+
expected_output_cols, _ = self._align_expected_output(
|
659
698
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
660
699
|
)
|
661
700
|
|
@@ -723,7 +762,7 @@ class CategoricalNB(BaseTransformer):
|
|
723
762
|
drop_input_cols=self._drop_input_cols,
|
724
763
|
expected_output_cols_type="float",
|
725
764
|
)
|
726
|
-
expected_output_cols = self.
|
765
|
+
expected_output_cols, _ = self._align_expected_output(
|
727
766
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
728
767
|
)
|
729
768
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -786,7 +825,7 @@ class CategoricalNB(BaseTransformer):
|
|
786
825
|
drop_input_cols=self._drop_input_cols,
|
787
826
|
expected_output_cols_type="float",
|
788
827
|
)
|
789
|
-
expected_output_cols = self.
|
828
|
+
expected_output_cols, _ = self._align_expected_output(
|
790
829
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
791
830
|
)
|
792
831
|
|
@@ -851,7 +890,7 @@ class CategoricalNB(BaseTransformer):
|
|
851
890
|
drop_input_cols = self._drop_input_cols,
|
852
891
|
expected_output_cols_type="float",
|
853
892
|
)
|
854
|
-
expected_output_cols = self.
|
893
|
+
expected_output_cols, _ = self._align_expected_output(
|
855
894
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
856
895
|
)
|
857
896
|
|