snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +7 -33
- snowflake/ml/_internal/env_utils.py +11 -5
- snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
- snowflake/ml/_internal/telemetry.py +156 -20
- snowflake/ml/_internal/utils/identifier.py +48 -11
- snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
- snowflake/ml/_internal/utils/snowflake_env.py +23 -13
- snowflake/ml/_internal/utils/sql_identifier.py +1 -1
- snowflake/ml/_internal/utils/table_manager.py +19 -1
- snowflake/ml/_internal/utils/uri.py +2 -2
- snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
- snowflake/ml/data/data_connector.py +88 -9
- snowflake/ml/data/data_ingestor.py +18 -1
- snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
- snowflake/ml/data/torch_utils.py +68 -0
- snowflake/ml/dataset/dataset.py +1 -3
- snowflake/ml/dataset/dataset_metadata.py +3 -1
- snowflake/ml/dataset/dataset_reader.py +9 -3
- snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
- snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
- snowflake/ml/feature_store/examples/example_helper.py +69 -31
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
- snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
- snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
- snowflake/ml/feature_store/feature_store.py +100 -41
- snowflake/ml/feature_store/feature_view.py +149 -5
- snowflake/ml/fileset/embedded_stage_fs.py +1 -1
- snowflake/ml/fileset/fileset.py +1 -1
- snowflake/ml/fileset/sfcfs.py +9 -3
- snowflake/ml/model/_client/model/model_impl.py +11 -2
- snowflake/ml/model/_client/model/model_version_impl.py +186 -20
- snowflake/ml/model/_client/ops/model_ops.py +144 -30
- snowflake/ml/model/_client/ops/service_ops.py +312 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
- snowflake/ml/model/_client/sql/model_version.py +13 -4
- snowflake/ml/model/_client/sql/service.py +196 -0
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
- snowflake/ml/model/_model_composer/model_composer.py +5 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
- snowflake/ml/model/_packager/model_env/model_env.py +7 -2
- snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
- snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
- snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
- snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
- snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
- snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
- snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
- snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
- snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
- snowflake/ml/model/_packager/model_packager.py +4 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/_signatures/utils.py +9 -0
- snowflake/ml/model/models/llm.py +3 -1
- snowflake/ml/model/type_hints.py +10 -4
- snowflake/ml/modeling/_internal/constants.py +1 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
- snowflake/ml/modeling/_internal/model_specifications.py +2 -0
- snowflake/ml/modeling/_internal/model_trainer.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
- snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
- snowflake/ml/modeling/cluster/birch.py +60 -21
- snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
- snowflake/ml/modeling/cluster/dbscan.py +60 -21
- snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
- snowflake/ml/modeling/cluster/k_means.py +60 -21
- snowflake/ml/modeling/cluster/mean_shift.py +60 -21
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
- snowflake/ml/modeling/cluster/optics.py +60 -21
- snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
- snowflake/ml/modeling/compose/column_transformer.py +60 -21
- snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
- snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
- snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
- snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
- snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
- snowflake/ml/modeling/covariance/oas.py +60 -21
- snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
- snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
- snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
- snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
- snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/pca.py +60 -21
- snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
- snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
- snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
- snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
- snowflake/ml/modeling/framework/base.py +28 -19
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
- snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
- snowflake/ml/modeling/impute/knn_imputer.py +60 -21
- snowflake/ml/modeling/impute/missing_indicator.py +60 -21
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/lars.py +60 -21
- snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
- snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/perceptron.py +60 -21
- snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ridge.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
- snowflake/ml/modeling/manifold/isomap.py +60 -21
- snowflake/ml/modeling/manifold/mds.py +60 -21
- snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
- snowflake/ml/modeling/manifold/tsne.py +60 -21
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
- snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
- snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +4 -12
- snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
- snowflake/ml/modeling/svm/linear_svc.py +60 -21
- snowflake/ml/modeling/svm/linear_svr.py +60 -21
- snowflake/ml/modeling/svm/nu_svc.py +60 -21
- snowflake/ml/modeling/svm/nu_svr.py +60 -21
- snowflake/ml/modeling/svm/svc.py +60 -21
- snowflake/ml/modeling/svm/svr.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
- snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
- snowflake/ml/registry/_manager/model_manager.py +20 -2
- snowflake/ml/registry/model_registry.py +1 -1
- snowflake/ml/registry/registry.py +1 -2
- snowflake/ml/utils/sql_client.py +22 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -540,12 +537,23 @@ class CalibratedClassifierCV(BaseTransformer):
|
|
540
537
|
autogenerated=self._autogenerated,
|
541
538
|
subproject=_SUBPROJECT,
|
542
539
|
)
|
543
|
-
|
544
|
-
|
545
|
-
expected_output_cols_list=(
|
546
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
547
|
-
),
|
540
|
+
expected_output_cols = (
|
541
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
548
542
|
)
|
543
|
+
if isinstance(dataset, DataFrame):
|
544
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
545
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
546
|
+
)
|
547
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
548
|
+
drop_input_cols=self._drop_input_cols,
|
549
|
+
expected_output_cols_list=expected_output_cols,
|
550
|
+
example_output_pd_df=example_output_pd_df,
|
551
|
+
)
|
552
|
+
else:
|
553
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
554
|
+
drop_input_cols=self._drop_input_cols,
|
555
|
+
expected_output_cols_list=expected_output_cols,
|
556
|
+
)
|
549
557
|
self._sklearn_object = fitted_estimator
|
550
558
|
self._is_fitted = True
|
551
559
|
return output_result
|
@@ -624,12 +632,41 @@ class CalibratedClassifierCV(BaseTransformer):
|
|
624
632
|
|
625
633
|
return rv
|
626
634
|
|
627
|
-
def
|
628
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
629
|
-
) -> List[str]:
|
635
|
+
def _align_expected_output(
|
636
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
637
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
638
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
639
|
+
and output dataframe with 1 line.
|
640
|
+
If the method is fit_predict, run 2 lines of data.
|
641
|
+
"""
|
630
642
|
# in case the inferred output column names dimension is different
|
631
643
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
632
|
-
|
644
|
+
|
645
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
646
|
+
# so change the minimum of number of rows to 2
|
647
|
+
num_examples = 2
|
648
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
649
|
+
project=_PROJECT,
|
650
|
+
subproject=_SUBPROJECT,
|
651
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
652
|
+
inspect.currentframe(), CalibratedClassifierCV.__class__.__name__
|
653
|
+
),
|
654
|
+
api_calls=[Session.call],
|
655
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
656
|
+
)
|
657
|
+
if output_cols_prefix == "fit_predict_":
|
658
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
659
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
660
|
+
num_examples = self._sklearn_object.n_clusters
|
661
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
662
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
663
|
+
num_examples = self._sklearn_object.min_samples
|
664
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
665
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
666
|
+
num_examples = self._sklearn_object.n_neighbors
|
667
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
668
|
+
else:
|
669
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
633
670
|
|
634
671
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
635
672
|
# seen during the fit.
|
@@ -641,12 +678,14 @@ class CalibratedClassifierCV(BaseTransformer):
|
|
641
678
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
642
679
|
if self.sample_weight_col:
|
643
680
|
output_df_columns_set -= set(self.sample_weight_col)
|
681
|
+
|
644
682
|
# if the dimension of inferred output column names is correct; use it
|
645
683
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
646
|
-
return expected_output_cols_list
|
684
|
+
return expected_output_cols_list, output_df_pd
|
647
685
|
# otherwise, use the sklearn estimator's output
|
648
686
|
else:
|
649
|
-
|
687
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
688
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
650
689
|
|
651
690
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
652
691
|
@telemetry.send_api_usage_telemetry(
|
@@ -694,7 +733,7 @@ class CalibratedClassifierCV(BaseTransformer):
|
|
694
733
|
drop_input_cols=self._drop_input_cols,
|
695
734
|
expected_output_cols_type="float",
|
696
735
|
)
|
697
|
-
expected_output_cols = self.
|
736
|
+
expected_output_cols, _ = self._align_expected_output(
|
698
737
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
699
738
|
)
|
700
739
|
|
@@ -762,7 +801,7 @@ class CalibratedClassifierCV(BaseTransformer):
|
|
762
801
|
drop_input_cols=self._drop_input_cols,
|
763
802
|
expected_output_cols_type="float",
|
764
803
|
)
|
765
|
-
expected_output_cols = self.
|
804
|
+
expected_output_cols, _ = self._align_expected_output(
|
766
805
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
767
806
|
)
|
768
807
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -825,7 +864,7 @@ class CalibratedClassifierCV(BaseTransformer):
|
|
825
864
|
drop_input_cols=self._drop_input_cols,
|
826
865
|
expected_output_cols_type="float",
|
827
866
|
)
|
828
|
-
expected_output_cols = self.
|
867
|
+
expected_output_cols, _ = self._align_expected_output(
|
829
868
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
830
869
|
)
|
831
870
|
|
@@ -890,7 +929,7 @@ class CalibratedClassifierCV(BaseTransformer):
|
|
890
929
|
drop_input_cols = self._drop_input_cols,
|
891
930
|
expected_output_cols_type="float",
|
892
931
|
)
|
893
|
-
expected_output_cols = self.
|
932
|
+
expected_output_cols, _ = self._align_expected_output(
|
894
933
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
895
934
|
)
|
896
935
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -519,12 +516,23 @@ class AffinityPropagation(BaseTransformer):
|
|
519
516
|
autogenerated=self._autogenerated,
|
520
517
|
subproject=_SUBPROJECT,
|
521
518
|
)
|
522
|
-
|
523
|
-
|
524
|
-
expected_output_cols_list=(
|
525
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
526
|
-
),
|
519
|
+
expected_output_cols = (
|
520
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
527
521
|
)
|
522
|
+
if isinstance(dataset, DataFrame):
|
523
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
524
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
525
|
+
)
|
526
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
527
|
+
drop_input_cols=self._drop_input_cols,
|
528
|
+
expected_output_cols_list=expected_output_cols,
|
529
|
+
example_output_pd_df=example_output_pd_df,
|
530
|
+
)
|
531
|
+
else:
|
532
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
533
|
+
drop_input_cols=self._drop_input_cols,
|
534
|
+
expected_output_cols_list=expected_output_cols,
|
535
|
+
)
|
528
536
|
self._sklearn_object = fitted_estimator
|
529
537
|
self._is_fitted = True
|
530
538
|
return output_result
|
@@ -603,12 +611,41 @@ class AffinityPropagation(BaseTransformer):
|
|
603
611
|
|
604
612
|
return rv
|
605
613
|
|
606
|
-
def
|
607
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
608
|
-
) -> List[str]:
|
614
|
+
def _align_expected_output(
|
615
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
616
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
617
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
618
|
+
and output dataframe with 1 line.
|
619
|
+
If the method is fit_predict, run 2 lines of data.
|
620
|
+
"""
|
609
621
|
# in case the inferred output column names dimension is different
|
610
622
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
611
|
-
|
623
|
+
|
624
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
625
|
+
# so change the minimum of number of rows to 2
|
626
|
+
num_examples = 2
|
627
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
628
|
+
project=_PROJECT,
|
629
|
+
subproject=_SUBPROJECT,
|
630
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
631
|
+
inspect.currentframe(), AffinityPropagation.__class__.__name__
|
632
|
+
),
|
633
|
+
api_calls=[Session.call],
|
634
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
635
|
+
)
|
636
|
+
if output_cols_prefix == "fit_predict_":
|
637
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
638
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
639
|
+
num_examples = self._sklearn_object.n_clusters
|
640
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
641
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
642
|
+
num_examples = self._sklearn_object.min_samples
|
643
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
644
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
645
|
+
num_examples = self._sklearn_object.n_neighbors
|
646
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
647
|
+
else:
|
648
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
612
649
|
|
613
650
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
614
651
|
# seen during the fit.
|
@@ -620,12 +657,14 @@ class AffinityPropagation(BaseTransformer):
|
|
620
657
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
621
658
|
if self.sample_weight_col:
|
622
659
|
output_df_columns_set -= set(self.sample_weight_col)
|
660
|
+
|
623
661
|
# if the dimension of inferred output column names is correct; use it
|
624
662
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
625
|
-
return expected_output_cols_list
|
663
|
+
return expected_output_cols_list, output_df_pd
|
626
664
|
# otherwise, use the sklearn estimator's output
|
627
665
|
else:
|
628
|
-
|
666
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
667
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
629
668
|
|
630
669
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
631
670
|
@telemetry.send_api_usage_telemetry(
|
@@ -671,7 +710,7 @@ class AffinityPropagation(BaseTransformer):
|
|
671
710
|
drop_input_cols=self._drop_input_cols,
|
672
711
|
expected_output_cols_type="float",
|
673
712
|
)
|
674
|
-
expected_output_cols = self.
|
713
|
+
expected_output_cols, _ = self._align_expected_output(
|
675
714
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
676
715
|
)
|
677
716
|
|
@@ -737,7 +776,7 @@ class AffinityPropagation(BaseTransformer):
|
|
737
776
|
drop_input_cols=self._drop_input_cols,
|
738
777
|
expected_output_cols_type="float",
|
739
778
|
)
|
740
|
-
expected_output_cols = self.
|
779
|
+
expected_output_cols, _ = self._align_expected_output(
|
741
780
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
742
781
|
)
|
743
782
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -800,7 +839,7 @@ class AffinityPropagation(BaseTransformer):
|
|
800
839
|
drop_input_cols=self._drop_input_cols,
|
801
840
|
expected_output_cols_type="float",
|
802
841
|
)
|
803
|
-
expected_output_cols = self.
|
842
|
+
expected_output_cols, _ = self._align_expected_output(
|
804
843
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
805
844
|
)
|
806
845
|
|
@@ -865,7 +904,7 @@ class AffinityPropagation(BaseTransformer):
|
|
865
904
|
drop_input_cols = self._drop_input_cols,
|
866
905
|
expected_output_cols_type="float",
|
867
906
|
)
|
868
|
-
expected_output_cols = self.
|
907
|
+
expected_output_cols, _ = self._align_expected_output(
|
869
908
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
870
909
|
)
|
871
910
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -550,12 +547,23 @@ class AgglomerativeClustering(BaseTransformer):
|
|
550
547
|
autogenerated=self._autogenerated,
|
551
548
|
subproject=_SUBPROJECT,
|
552
549
|
)
|
553
|
-
|
554
|
-
|
555
|
-
expected_output_cols_list=(
|
556
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
557
|
-
),
|
550
|
+
expected_output_cols = (
|
551
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
558
552
|
)
|
553
|
+
if isinstance(dataset, DataFrame):
|
554
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
555
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
556
|
+
)
|
557
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
558
|
+
drop_input_cols=self._drop_input_cols,
|
559
|
+
expected_output_cols_list=expected_output_cols,
|
560
|
+
example_output_pd_df=example_output_pd_df,
|
561
|
+
)
|
562
|
+
else:
|
563
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
564
|
+
drop_input_cols=self._drop_input_cols,
|
565
|
+
expected_output_cols_list=expected_output_cols,
|
566
|
+
)
|
559
567
|
self._sklearn_object = fitted_estimator
|
560
568
|
self._is_fitted = True
|
561
569
|
return output_result
|
@@ -634,12 +642,41 @@ class AgglomerativeClustering(BaseTransformer):
|
|
634
642
|
|
635
643
|
return rv
|
636
644
|
|
637
|
-
def
|
638
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
639
|
-
) -> List[str]:
|
645
|
+
def _align_expected_output(
|
646
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
647
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
648
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
649
|
+
and output dataframe with 1 line.
|
650
|
+
If the method is fit_predict, run 2 lines of data.
|
651
|
+
"""
|
640
652
|
# in case the inferred output column names dimension is different
|
641
653
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
642
|
-
|
654
|
+
|
655
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
656
|
+
# so change the minimum of number of rows to 2
|
657
|
+
num_examples = 2
|
658
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
659
|
+
project=_PROJECT,
|
660
|
+
subproject=_SUBPROJECT,
|
661
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
662
|
+
inspect.currentframe(), AgglomerativeClustering.__class__.__name__
|
663
|
+
),
|
664
|
+
api_calls=[Session.call],
|
665
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
666
|
+
)
|
667
|
+
if output_cols_prefix == "fit_predict_":
|
668
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
669
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
670
|
+
num_examples = self._sklearn_object.n_clusters
|
671
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
672
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
673
|
+
num_examples = self._sklearn_object.min_samples
|
674
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
675
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
676
|
+
num_examples = self._sklearn_object.n_neighbors
|
677
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
678
|
+
else:
|
679
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
643
680
|
|
644
681
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
645
682
|
# seen during the fit.
|
@@ -651,12 +688,14 @@ class AgglomerativeClustering(BaseTransformer):
|
|
651
688
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
652
689
|
if self.sample_weight_col:
|
653
690
|
output_df_columns_set -= set(self.sample_weight_col)
|
691
|
+
|
654
692
|
# if the dimension of inferred output column names is correct; use it
|
655
693
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
656
|
-
return expected_output_cols_list
|
694
|
+
return expected_output_cols_list, output_df_pd
|
657
695
|
# otherwise, use the sklearn estimator's output
|
658
696
|
else:
|
659
|
-
|
697
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
698
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
660
699
|
|
661
700
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
662
701
|
@telemetry.send_api_usage_telemetry(
|
@@ -702,7 +741,7 @@ class AgglomerativeClustering(BaseTransformer):
|
|
702
741
|
drop_input_cols=self._drop_input_cols,
|
703
742
|
expected_output_cols_type="float",
|
704
743
|
)
|
705
|
-
expected_output_cols = self.
|
744
|
+
expected_output_cols, _ = self._align_expected_output(
|
706
745
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
707
746
|
)
|
708
747
|
|
@@ -768,7 +807,7 @@ class AgglomerativeClustering(BaseTransformer):
|
|
768
807
|
drop_input_cols=self._drop_input_cols,
|
769
808
|
expected_output_cols_type="float",
|
770
809
|
)
|
771
|
-
expected_output_cols = self.
|
810
|
+
expected_output_cols, _ = self._align_expected_output(
|
772
811
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
773
812
|
)
|
774
813
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -831,7 +870,7 @@ class AgglomerativeClustering(BaseTransformer):
|
|
831
870
|
drop_input_cols=self._drop_input_cols,
|
832
871
|
expected_output_cols_type="float",
|
833
872
|
)
|
834
|
-
expected_output_cols = self.
|
873
|
+
expected_output_cols, _ = self._align_expected_output(
|
835
874
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
836
875
|
)
|
837
876
|
|
@@ -896,7 +935,7 @@ class AgglomerativeClustering(BaseTransformer):
|
|
896
935
|
drop_input_cols = self._drop_input_cols,
|
897
936
|
expected_output_cols_type="float",
|
898
937
|
)
|
899
|
-
expected_output_cols = self.
|
938
|
+
expected_output_cols, _ = self._align_expected_output(
|
900
939
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
901
940
|
)
|
902
941
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -512,12 +509,23 @@ class Birch(BaseTransformer):
|
|
512
509
|
autogenerated=self._autogenerated,
|
513
510
|
subproject=_SUBPROJECT,
|
514
511
|
)
|
515
|
-
|
516
|
-
|
517
|
-
expected_output_cols_list=(
|
518
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
519
|
-
),
|
512
|
+
expected_output_cols = (
|
513
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
520
514
|
)
|
515
|
+
if isinstance(dataset, DataFrame):
|
516
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
517
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
518
|
+
)
|
519
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
520
|
+
drop_input_cols=self._drop_input_cols,
|
521
|
+
expected_output_cols_list=expected_output_cols,
|
522
|
+
example_output_pd_df=example_output_pd_df,
|
523
|
+
)
|
524
|
+
else:
|
525
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
526
|
+
drop_input_cols=self._drop_input_cols,
|
527
|
+
expected_output_cols_list=expected_output_cols,
|
528
|
+
)
|
521
529
|
self._sklearn_object = fitted_estimator
|
522
530
|
self._is_fitted = True
|
523
531
|
return output_result
|
@@ -598,12 +606,41 @@ class Birch(BaseTransformer):
|
|
598
606
|
|
599
607
|
return rv
|
600
608
|
|
601
|
-
def
|
602
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
603
|
-
) -> List[str]:
|
609
|
+
def _align_expected_output(
|
610
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
611
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
612
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
613
|
+
and output dataframe with 1 line.
|
614
|
+
If the method is fit_predict, run 2 lines of data.
|
615
|
+
"""
|
604
616
|
# in case the inferred output column names dimension is different
|
605
617
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
606
|
-
|
618
|
+
|
619
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
620
|
+
# so change the minimum of number of rows to 2
|
621
|
+
num_examples = 2
|
622
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
623
|
+
project=_PROJECT,
|
624
|
+
subproject=_SUBPROJECT,
|
625
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
626
|
+
inspect.currentframe(), Birch.__class__.__name__
|
627
|
+
),
|
628
|
+
api_calls=[Session.call],
|
629
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
630
|
+
)
|
631
|
+
if output_cols_prefix == "fit_predict_":
|
632
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
633
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
634
|
+
num_examples = self._sklearn_object.n_clusters
|
635
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
636
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
637
|
+
num_examples = self._sklearn_object.min_samples
|
638
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
639
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
640
|
+
num_examples = self._sklearn_object.n_neighbors
|
641
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
642
|
+
else:
|
643
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
607
644
|
|
608
645
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
609
646
|
# seen during the fit.
|
@@ -615,12 +652,14 @@ class Birch(BaseTransformer):
|
|
615
652
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
616
653
|
if self.sample_weight_col:
|
617
654
|
output_df_columns_set -= set(self.sample_weight_col)
|
655
|
+
|
618
656
|
# if the dimension of inferred output column names is correct; use it
|
619
657
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
620
|
-
return expected_output_cols_list
|
658
|
+
return expected_output_cols_list, output_df_pd
|
621
659
|
# otherwise, use the sklearn estimator's output
|
622
660
|
else:
|
623
|
-
|
661
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
662
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
624
663
|
|
625
664
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
626
665
|
@telemetry.send_api_usage_telemetry(
|
@@ -666,7 +705,7 @@ class Birch(BaseTransformer):
|
|
666
705
|
drop_input_cols=self._drop_input_cols,
|
667
706
|
expected_output_cols_type="float",
|
668
707
|
)
|
669
|
-
expected_output_cols = self.
|
708
|
+
expected_output_cols, _ = self._align_expected_output(
|
670
709
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
671
710
|
)
|
672
711
|
|
@@ -732,7 +771,7 @@ class Birch(BaseTransformer):
|
|
732
771
|
drop_input_cols=self._drop_input_cols,
|
733
772
|
expected_output_cols_type="float",
|
734
773
|
)
|
735
|
-
expected_output_cols = self.
|
774
|
+
expected_output_cols, _ = self._align_expected_output(
|
736
775
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
737
776
|
)
|
738
777
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -795,7 +834,7 @@ class Birch(BaseTransformer):
|
|
795
834
|
drop_input_cols=self._drop_input_cols,
|
796
835
|
expected_output_cols_type="float",
|
797
836
|
)
|
798
|
-
expected_output_cols = self.
|
837
|
+
expected_output_cols, _ = self._align_expected_output(
|
799
838
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
800
839
|
)
|
801
840
|
|
@@ -860,7 +899,7 @@ class Birch(BaseTransformer):
|
|
860
899
|
drop_input_cols = self._drop_input_cols,
|
861
900
|
expected_output_cols_type="float",
|
862
901
|
)
|
863
|
-
expected_output_cols = self.
|
902
|
+
expected_output_cols, _ = self._align_expected_output(
|
864
903
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
865
904
|
)
|
866
905
|
|