snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -0
- snowflake/cortex/_classify_text.py +2 -2
- snowflake/cortex/_embed_text_1024.py +37 -0
- snowflake/cortex/_embed_text_768.py +37 -0
- snowflake/cortex/_extract_answer.py +2 -2
- snowflake/cortex/_sentiment.py +2 -2
- snowflake/cortex/_summarize.py +2 -2
- snowflake/cortex/_translate.py +2 -2
- snowflake/cortex/_util.py +4 -4
- snowflake/ml/_internal/env_utils.py +5 -5
- snowflake/ml/_internal/exceptions/error_codes.py +2 -0
- snowflake/ml/_internal/telemetry.py +142 -20
- snowflake/ml/_internal/utils/db_utils.py +50 -0
- snowflake/ml/_internal/utils/identifier.py +48 -11
- snowflake/ml/_internal/utils/service_logger.py +63 -0
- snowflake/ml/_internal/utils/snowflake_env.py +23 -13
- snowflake/ml/_internal/utils/sql_identifier.py +26 -2
- snowflake/ml/_internal/utils/table_manager.py +19 -1
- snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
- snowflake/ml/data/data_connector.py +33 -7
- snowflake/ml/data/ingestor_utils.py +20 -10
- snowflake/ml/data/torch_utils.py +68 -0
- snowflake/ml/dataset/dataset.py +1 -3
- snowflake/ml/feature_store/access_manager.py +3 -3
- snowflake/ml/feature_store/feature_store.py +60 -19
- snowflake/ml/feature_store/feature_view.py +84 -30
- snowflake/ml/fileset/embedded_stage_fs.py +1 -1
- snowflake/ml/fileset/fileset.py +1 -1
- snowflake/ml/fileset/sfcfs.py +9 -3
- snowflake/ml/fileset/stage_fs.py +2 -1
- snowflake/ml/lineage/lineage_node.py +7 -2
- snowflake/ml/model/__init__.py +1 -2
- snowflake/ml/model/_client/model/model_version_impl.py +96 -12
- snowflake/ml/model/_client/ops/model_ops.py +124 -6
- snowflake/ml/model/_client/ops/service_ops.py +309 -9
- snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
- snowflake/ml/model/_client/sql/_base.py +5 -0
- snowflake/ml/model/_client/sql/model.py +1 -0
- snowflake/ml/model/_client/sql/model_version.py +9 -5
- snowflake/ml/model/_client/sql/service.py +121 -20
- snowflake/ml/model/_model_composer/model_composer.py +11 -39
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
- snowflake/ml/model/_packager/model_env/model_env.py +4 -38
- snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
- snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
- snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
- snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
- snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
- snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
- snowflake/ml/model/_packager/model_packager.py +14 -8
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
- snowflake/ml/model/_signatures/utils.py +9 -0
- snowflake/ml/model/type_hints.py +12 -145
- snowflake/ml/modeling/_internal/constants.py +1 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
- snowflake/ml/modeling/_internal/model_specifications.py +2 -0
- snowflake/ml/modeling/_internal/model_trainer.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
- snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
- snowflake/ml/modeling/cluster/birch.py +61 -21
- snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
- snowflake/ml/modeling/cluster/dbscan.py +61 -21
- snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
- snowflake/ml/modeling/cluster/k_means.py +61 -21
- snowflake/ml/modeling/cluster/mean_shift.py +61 -21
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
- snowflake/ml/modeling/cluster/optics.py +61 -21
- snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
- snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
- snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
- snowflake/ml/modeling/compose/column_transformer.py +61 -21
- snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
- snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
- snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
- snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
- snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
- snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
- snowflake/ml/modeling/covariance/oas.py +61 -21
- snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
- snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
- snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
- snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
- snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
- snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
- snowflake/ml/modeling/decomposition/pca.py +61 -21
- snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
- snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
- snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
- snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
- snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
- snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
- snowflake/ml/modeling/impute/knn_imputer.py +61 -21
- snowflake/ml/modeling/impute/missing_indicator.py +61 -21
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
- snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
- snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
- snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/lars.py +61 -21
- snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
- snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
- snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/perceptron.py +61 -21
- snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ridge.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
- snowflake/ml/modeling/manifold/isomap.py +61 -21
- snowflake/ml/modeling/manifold/mds.py +61 -21
- snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
- snowflake/ml/modeling/manifold/tsne.py +61 -21
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
- snowflake/ml/modeling/metrics/ranking.py +0 -3
- snowflake/ml/modeling/metrics/regression.py +0 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
- snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
- snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
- snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
- snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
- snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
- snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
- snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +1 -13
- snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
- snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
- snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
- snowflake/ml/modeling/svm/linear_svc.py +61 -21
- snowflake/ml/modeling/svm/linear_svr.py +61 -21
- snowflake/ml/modeling/svm/nu_svc.py +61 -21
- snowflake/ml/modeling/svm/nu_svr.py +61 -21
- snowflake/ml/modeling/svm/svc.py +61 -21
- snowflake/ml/modeling/svm/svr.py +61 -21
- snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
- snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
- snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
- snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
- snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
- snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
- snowflake/ml/monitoring/_client/model_monitor.py +126 -0
- snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
- snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
- snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
- snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
- snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
- snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
- snowflake/ml/monitoring/entities/output_score_type.py +90 -0
- snowflake/ml/registry/_manager/model_manager.py +4 -0
- snowflake/ml/registry/registry.py +166 -8
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
- snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
- snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
- snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
- snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
- snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
- snowflake/ml/_internal/utils/session_token_manager.py +0 -46
- snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
- snowflake/ml/_internal/utils/uri.py +0 -77
- snowflake/ml/data/torch_dataset.py +0 -33
- snowflake/ml/model/_api.py +0 -568
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
- snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
- snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
- snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
- snowflake/ml/model/deploy_platforms.py +0 -6
- snowflake/ml/model/models/llm.py +0 -104
- snowflake/ml/monitoring/monitor.py +0 -203
- snowflake/ml/registry/_initial_schema.py +0 -142
- snowflake/ml/registry/_schema.py +0 -82
- snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
- snowflake/ml/registry/_schema_version_manager.py +0 -163
- snowflake/ml/registry/model_registry.py +0 -2048
- snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -540,12 +537,23 @@ class CalibratedClassifierCV(BaseTransformer):
|
|
540
537
|
autogenerated=self._autogenerated,
|
541
538
|
subproject=_SUBPROJECT,
|
542
539
|
)
|
543
|
-
|
544
|
-
|
545
|
-
expected_output_cols_list=(
|
546
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
547
|
-
),
|
540
|
+
expected_output_cols = (
|
541
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
548
542
|
)
|
543
|
+
if isinstance(dataset, DataFrame):
|
544
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
545
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
546
|
+
)
|
547
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
548
|
+
drop_input_cols=self._drop_input_cols,
|
549
|
+
expected_output_cols_list=expected_output_cols,
|
550
|
+
example_output_pd_df=example_output_pd_df,
|
551
|
+
)
|
552
|
+
else:
|
553
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
554
|
+
drop_input_cols=self._drop_input_cols,
|
555
|
+
expected_output_cols_list=expected_output_cols,
|
556
|
+
)
|
549
557
|
self._sklearn_object = fitted_estimator
|
550
558
|
self._is_fitted = True
|
551
559
|
return output_result
|
@@ -568,6 +576,7 @@ class CalibratedClassifierCV(BaseTransformer):
|
|
568
576
|
"""
|
569
577
|
self._infer_input_output_cols(dataset)
|
570
578
|
super()._check_dataset_type(dataset)
|
579
|
+
|
571
580
|
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
572
581
|
estimator=self._sklearn_object,
|
573
582
|
dataset=dataset,
|
@@ -624,12 +633,41 @@ class CalibratedClassifierCV(BaseTransformer):
|
|
624
633
|
|
625
634
|
return rv
|
626
635
|
|
627
|
-
def
|
628
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
629
|
-
) -> List[str]:
|
636
|
+
def _align_expected_output(
|
637
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
638
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
639
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
640
|
+
and output dataframe with 1 line.
|
641
|
+
If the method is fit_predict, run 2 lines of data.
|
642
|
+
"""
|
630
643
|
# in case the inferred output column names dimension is different
|
631
644
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
632
|
-
|
645
|
+
|
646
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
647
|
+
# so change the minimum of number of rows to 2
|
648
|
+
num_examples = 2
|
649
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
650
|
+
project=_PROJECT,
|
651
|
+
subproject=_SUBPROJECT,
|
652
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
653
|
+
inspect.currentframe(), CalibratedClassifierCV.__class__.__name__
|
654
|
+
),
|
655
|
+
api_calls=[Session.call],
|
656
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
657
|
+
)
|
658
|
+
if output_cols_prefix == "fit_predict_":
|
659
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
660
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
661
|
+
num_examples = self._sklearn_object.n_clusters
|
662
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
663
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
664
|
+
num_examples = self._sklearn_object.min_samples
|
665
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
666
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
667
|
+
num_examples = self._sklearn_object.n_neighbors
|
668
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
669
|
+
else:
|
670
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
633
671
|
|
634
672
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
635
673
|
# seen during the fit.
|
@@ -641,12 +679,14 @@ class CalibratedClassifierCV(BaseTransformer):
|
|
641
679
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
642
680
|
if self.sample_weight_col:
|
643
681
|
output_df_columns_set -= set(self.sample_weight_col)
|
682
|
+
|
644
683
|
# if the dimension of inferred output column names is correct; use it
|
645
684
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
646
|
-
return expected_output_cols_list
|
685
|
+
return expected_output_cols_list, output_df_pd
|
647
686
|
# otherwise, use the sklearn estimator's output
|
648
687
|
else:
|
649
|
-
|
688
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
689
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
650
690
|
|
651
691
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
652
692
|
@telemetry.send_api_usage_telemetry(
|
@@ -694,7 +734,7 @@ class CalibratedClassifierCV(BaseTransformer):
|
|
694
734
|
drop_input_cols=self._drop_input_cols,
|
695
735
|
expected_output_cols_type="float",
|
696
736
|
)
|
697
|
-
expected_output_cols = self.
|
737
|
+
expected_output_cols, _ = self._align_expected_output(
|
698
738
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
699
739
|
)
|
700
740
|
|
@@ -762,7 +802,7 @@ class CalibratedClassifierCV(BaseTransformer):
|
|
762
802
|
drop_input_cols=self._drop_input_cols,
|
763
803
|
expected_output_cols_type="float",
|
764
804
|
)
|
765
|
-
expected_output_cols = self.
|
805
|
+
expected_output_cols, _ = self._align_expected_output(
|
766
806
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
767
807
|
)
|
768
808
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -825,7 +865,7 @@ class CalibratedClassifierCV(BaseTransformer):
|
|
825
865
|
drop_input_cols=self._drop_input_cols,
|
826
866
|
expected_output_cols_type="float",
|
827
867
|
)
|
828
|
-
expected_output_cols = self.
|
868
|
+
expected_output_cols, _ = self._align_expected_output(
|
829
869
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
830
870
|
)
|
831
871
|
|
@@ -890,7 +930,7 @@ class CalibratedClassifierCV(BaseTransformer):
|
|
890
930
|
drop_input_cols = self._drop_input_cols,
|
891
931
|
expected_output_cols_type="float",
|
892
932
|
)
|
893
|
-
expected_output_cols = self.
|
933
|
+
expected_output_cols, _ = self._align_expected_output(
|
894
934
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
895
935
|
)
|
896
936
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -519,12 +516,23 @@ class AffinityPropagation(BaseTransformer):
|
|
519
516
|
autogenerated=self._autogenerated,
|
520
517
|
subproject=_SUBPROJECT,
|
521
518
|
)
|
522
|
-
|
523
|
-
|
524
|
-
expected_output_cols_list=(
|
525
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
526
|
-
),
|
519
|
+
expected_output_cols = (
|
520
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
527
521
|
)
|
522
|
+
if isinstance(dataset, DataFrame):
|
523
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
524
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
525
|
+
)
|
526
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
527
|
+
drop_input_cols=self._drop_input_cols,
|
528
|
+
expected_output_cols_list=expected_output_cols,
|
529
|
+
example_output_pd_df=example_output_pd_df,
|
530
|
+
)
|
531
|
+
else:
|
532
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
533
|
+
drop_input_cols=self._drop_input_cols,
|
534
|
+
expected_output_cols_list=expected_output_cols,
|
535
|
+
)
|
528
536
|
self._sklearn_object = fitted_estimator
|
529
537
|
self._is_fitted = True
|
530
538
|
return output_result
|
@@ -547,6 +555,7 @@ class AffinityPropagation(BaseTransformer):
|
|
547
555
|
"""
|
548
556
|
self._infer_input_output_cols(dataset)
|
549
557
|
super()._check_dataset_type(dataset)
|
558
|
+
|
550
559
|
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
551
560
|
estimator=self._sklearn_object,
|
552
561
|
dataset=dataset,
|
@@ -603,12 +612,41 @@ class AffinityPropagation(BaseTransformer):
|
|
603
612
|
|
604
613
|
return rv
|
605
614
|
|
606
|
-
def
|
607
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
608
|
-
) -> List[str]:
|
615
|
+
def _align_expected_output(
|
616
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
617
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
618
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
619
|
+
and output dataframe with 1 line.
|
620
|
+
If the method is fit_predict, run 2 lines of data.
|
621
|
+
"""
|
609
622
|
# in case the inferred output column names dimension is different
|
610
623
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
611
|
-
|
624
|
+
|
625
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
626
|
+
# so change the minimum of number of rows to 2
|
627
|
+
num_examples = 2
|
628
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
629
|
+
project=_PROJECT,
|
630
|
+
subproject=_SUBPROJECT,
|
631
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
632
|
+
inspect.currentframe(), AffinityPropagation.__class__.__name__
|
633
|
+
),
|
634
|
+
api_calls=[Session.call],
|
635
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
636
|
+
)
|
637
|
+
if output_cols_prefix == "fit_predict_":
|
638
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
639
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
640
|
+
num_examples = self._sklearn_object.n_clusters
|
641
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
642
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
643
|
+
num_examples = self._sklearn_object.min_samples
|
644
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
645
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
646
|
+
num_examples = self._sklearn_object.n_neighbors
|
647
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
648
|
+
else:
|
649
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
612
650
|
|
613
651
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
614
652
|
# seen during the fit.
|
@@ -620,12 +658,14 @@ class AffinityPropagation(BaseTransformer):
|
|
620
658
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
621
659
|
if self.sample_weight_col:
|
622
660
|
output_df_columns_set -= set(self.sample_weight_col)
|
661
|
+
|
623
662
|
# if the dimension of inferred output column names is correct; use it
|
624
663
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
625
|
-
return expected_output_cols_list
|
664
|
+
return expected_output_cols_list, output_df_pd
|
626
665
|
# otherwise, use the sklearn estimator's output
|
627
666
|
else:
|
628
|
-
|
667
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
668
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
629
669
|
|
630
670
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
631
671
|
@telemetry.send_api_usage_telemetry(
|
@@ -671,7 +711,7 @@ class AffinityPropagation(BaseTransformer):
|
|
671
711
|
drop_input_cols=self._drop_input_cols,
|
672
712
|
expected_output_cols_type="float",
|
673
713
|
)
|
674
|
-
expected_output_cols = self.
|
714
|
+
expected_output_cols, _ = self._align_expected_output(
|
675
715
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
676
716
|
)
|
677
717
|
|
@@ -737,7 +777,7 @@ class AffinityPropagation(BaseTransformer):
|
|
737
777
|
drop_input_cols=self._drop_input_cols,
|
738
778
|
expected_output_cols_type="float",
|
739
779
|
)
|
740
|
-
expected_output_cols = self.
|
780
|
+
expected_output_cols, _ = self._align_expected_output(
|
741
781
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
742
782
|
)
|
743
783
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -800,7 +840,7 @@ class AffinityPropagation(BaseTransformer):
|
|
800
840
|
drop_input_cols=self._drop_input_cols,
|
801
841
|
expected_output_cols_type="float",
|
802
842
|
)
|
803
|
-
expected_output_cols = self.
|
843
|
+
expected_output_cols, _ = self._align_expected_output(
|
804
844
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
805
845
|
)
|
806
846
|
|
@@ -865,7 +905,7 @@ class AffinityPropagation(BaseTransformer):
|
|
865
905
|
drop_input_cols = self._drop_input_cols,
|
866
906
|
expected_output_cols_type="float",
|
867
907
|
)
|
868
|
-
expected_output_cols = self.
|
908
|
+
expected_output_cols, _ = self._align_expected_output(
|
869
909
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
870
910
|
)
|
871
911
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -550,12 +547,23 @@ class AgglomerativeClustering(BaseTransformer):
|
|
550
547
|
autogenerated=self._autogenerated,
|
551
548
|
subproject=_SUBPROJECT,
|
552
549
|
)
|
553
|
-
|
554
|
-
|
555
|
-
expected_output_cols_list=(
|
556
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
557
|
-
),
|
550
|
+
expected_output_cols = (
|
551
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
558
552
|
)
|
553
|
+
if isinstance(dataset, DataFrame):
|
554
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
555
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
556
|
+
)
|
557
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
558
|
+
drop_input_cols=self._drop_input_cols,
|
559
|
+
expected_output_cols_list=expected_output_cols,
|
560
|
+
example_output_pd_df=example_output_pd_df,
|
561
|
+
)
|
562
|
+
else:
|
563
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
564
|
+
drop_input_cols=self._drop_input_cols,
|
565
|
+
expected_output_cols_list=expected_output_cols,
|
566
|
+
)
|
559
567
|
self._sklearn_object = fitted_estimator
|
560
568
|
self._is_fitted = True
|
561
569
|
return output_result
|
@@ -578,6 +586,7 @@ class AgglomerativeClustering(BaseTransformer):
|
|
578
586
|
"""
|
579
587
|
self._infer_input_output_cols(dataset)
|
580
588
|
super()._check_dataset_type(dataset)
|
589
|
+
|
581
590
|
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
582
591
|
estimator=self._sklearn_object,
|
583
592
|
dataset=dataset,
|
@@ -634,12 +643,41 @@ class AgglomerativeClustering(BaseTransformer):
|
|
634
643
|
|
635
644
|
return rv
|
636
645
|
|
637
|
-
def
|
638
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
639
|
-
) -> List[str]:
|
646
|
+
def _align_expected_output(
|
647
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
648
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
649
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
650
|
+
and output dataframe with 1 line.
|
651
|
+
If the method is fit_predict, run 2 lines of data.
|
652
|
+
"""
|
640
653
|
# in case the inferred output column names dimension is different
|
641
654
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
642
|
-
|
655
|
+
|
656
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
657
|
+
# so change the minimum of number of rows to 2
|
658
|
+
num_examples = 2
|
659
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
660
|
+
project=_PROJECT,
|
661
|
+
subproject=_SUBPROJECT,
|
662
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
663
|
+
inspect.currentframe(), AgglomerativeClustering.__class__.__name__
|
664
|
+
),
|
665
|
+
api_calls=[Session.call],
|
666
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
667
|
+
)
|
668
|
+
if output_cols_prefix == "fit_predict_":
|
669
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
670
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
671
|
+
num_examples = self._sklearn_object.n_clusters
|
672
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
673
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
674
|
+
num_examples = self._sklearn_object.min_samples
|
675
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
676
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
677
|
+
num_examples = self._sklearn_object.n_neighbors
|
678
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
679
|
+
else:
|
680
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
643
681
|
|
644
682
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
645
683
|
# seen during the fit.
|
@@ -651,12 +689,14 @@ class AgglomerativeClustering(BaseTransformer):
|
|
651
689
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
652
690
|
if self.sample_weight_col:
|
653
691
|
output_df_columns_set -= set(self.sample_weight_col)
|
692
|
+
|
654
693
|
# if the dimension of inferred output column names is correct; use it
|
655
694
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
656
|
-
return expected_output_cols_list
|
695
|
+
return expected_output_cols_list, output_df_pd
|
657
696
|
# otherwise, use the sklearn estimator's output
|
658
697
|
else:
|
659
|
-
|
698
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
699
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
660
700
|
|
661
701
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
662
702
|
@telemetry.send_api_usage_telemetry(
|
@@ -702,7 +742,7 @@ class AgglomerativeClustering(BaseTransformer):
|
|
702
742
|
drop_input_cols=self._drop_input_cols,
|
703
743
|
expected_output_cols_type="float",
|
704
744
|
)
|
705
|
-
expected_output_cols = self.
|
745
|
+
expected_output_cols, _ = self._align_expected_output(
|
706
746
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
707
747
|
)
|
708
748
|
|
@@ -768,7 +808,7 @@ class AgglomerativeClustering(BaseTransformer):
|
|
768
808
|
drop_input_cols=self._drop_input_cols,
|
769
809
|
expected_output_cols_type="float",
|
770
810
|
)
|
771
|
-
expected_output_cols = self.
|
811
|
+
expected_output_cols, _ = self._align_expected_output(
|
772
812
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
773
813
|
)
|
774
814
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -831,7 +871,7 @@ class AgglomerativeClustering(BaseTransformer):
|
|
831
871
|
drop_input_cols=self._drop_input_cols,
|
832
872
|
expected_output_cols_type="float",
|
833
873
|
)
|
834
|
-
expected_output_cols = self.
|
874
|
+
expected_output_cols, _ = self._align_expected_output(
|
835
875
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
836
876
|
)
|
837
877
|
|
@@ -896,7 +936,7 @@ class AgglomerativeClustering(BaseTransformer):
|
|
896
936
|
drop_input_cols = self._drop_input_cols,
|
897
937
|
expected_output_cols_type="float",
|
898
938
|
)
|
899
|
-
expected_output_cols = self.
|
939
|
+
expected_output_cols, _ = self._align_expected_output(
|
900
940
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
901
941
|
)
|
902
942
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -512,12 +509,23 @@ class Birch(BaseTransformer):
|
|
512
509
|
autogenerated=self._autogenerated,
|
513
510
|
subproject=_SUBPROJECT,
|
514
511
|
)
|
515
|
-
|
516
|
-
|
517
|
-
expected_output_cols_list=(
|
518
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
519
|
-
),
|
512
|
+
expected_output_cols = (
|
513
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
520
514
|
)
|
515
|
+
if isinstance(dataset, DataFrame):
|
516
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
517
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
518
|
+
)
|
519
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
520
|
+
drop_input_cols=self._drop_input_cols,
|
521
|
+
expected_output_cols_list=expected_output_cols,
|
522
|
+
example_output_pd_df=example_output_pd_df,
|
523
|
+
)
|
524
|
+
else:
|
525
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
526
|
+
drop_input_cols=self._drop_input_cols,
|
527
|
+
expected_output_cols_list=expected_output_cols,
|
528
|
+
)
|
521
529
|
self._sklearn_object = fitted_estimator
|
522
530
|
self._is_fitted = True
|
523
531
|
return output_result
|
@@ -542,6 +550,7 @@ class Birch(BaseTransformer):
|
|
542
550
|
"""
|
543
551
|
self._infer_input_output_cols(dataset)
|
544
552
|
super()._check_dataset_type(dataset)
|
553
|
+
|
545
554
|
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
546
555
|
estimator=self._sklearn_object,
|
547
556
|
dataset=dataset,
|
@@ -598,12 +607,41 @@ class Birch(BaseTransformer):
|
|
598
607
|
|
599
608
|
return rv
|
600
609
|
|
601
|
-
def
|
602
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
603
|
-
) -> List[str]:
|
610
|
+
def _align_expected_output(
|
611
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
612
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
613
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
614
|
+
and output dataframe with 1 line.
|
615
|
+
If the method is fit_predict, run 2 lines of data.
|
616
|
+
"""
|
604
617
|
# in case the inferred output column names dimension is different
|
605
618
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
606
|
-
|
619
|
+
|
620
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
621
|
+
# so change the minimum of number of rows to 2
|
622
|
+
num_examples = 2
|
623
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
624
|
+
project=_PROJECT,
|
625
|
+
subproject=_SUBPROJECT,
|
626
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
627
|
+
inspect.currentframe(), Birch.__class__.__name__
|
628
|
+
),
|
629
|
+
api_calls=[Session.call],
|
630
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
631
|
+
)
|
632
|
+
if output_cols_prefix == "fit_predict_":
|
633
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
634
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
635
|
+
num_examples = self._sklearn_object.n_clusters
|
636
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
637
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
638
|
+
num_examples = self._sklearn_object.min_samples
|
639
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
640
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
641
|
+
num_examples = self._sklearn_object.n_neighbors
|
642
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
643
|
+
else:
|
644
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
607
645
|
|
608
646
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
609
647
|
# seen during the fit.
|
@@ -615,12 +653,14 @@ class Birch(BaseTransformer):
|
|
615
653
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
616
654
|
if self.sample_weight_col:
|
617
655
|
output_df_columns_set -= set(self.sample_weight_col)
|
656
|
+
|
618
657
|
# if the dimension of inferred output column names is correct; use it
|
619
658
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
620
|
-
return expected_output_cols_list
|
659
|
+
return expected_output_cols_list, output_df_pd
|
621
660
|
# otherwise, use the sklearn estimator's output
|
622
661
|
else:
|
623
|
-
|
662
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
663
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
624
664
|
|
625
665
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
626
666
|
@telemetry.send_api_usage_telemetry(
|
@@ -666,7 +706,7 @@ class Birch(BaseTransformer):
|
|
666
706
|
drop_input_cols=self._drop_input_cols,
|
667
707
|
expected_output_cols_type="float",
|
668
708
|
)
|
669
|
-
expected_output_cols = self.
|
709
|
+
expected_output_cols, _ = self._align_expected_output(
|
670
710
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
671
711
|
)
|
672
712
|
|
@@ -732,7 +772,7 @@ class Birch(BaseTransformer):
|
|
732
772
|
drop_input_cols=self._drop_input_cols,
|
733
773
|
expected_output_cols_type="float",
|
734
774
|
)
|
735
|
-
expected_output_cols = self.
|
775
|
+
expected_output_cols, _ = self._align_expected_output(
|
736
776
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
737
777
|
)
|
738
778
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -795,7 +835,7 @@ class Birch(BaseTransformer):
|
|
795
835
|
drop_input_cols=self._drop_input_cols,
|
796
836
|
expected_output_cols_type="float",
|
797
837
|
)
|
798
|
-
expected_output_cols = self.
|
838
|
+
expected_output_cols, _ = self._align_expected_output(
|
799
839
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
800
840
|
)
|
801
841
|
|
@@ -860,7 +900,7 @@ class Birch(BaseTransformer):
|
|
860
900
|
drop_input_cols = self._drop_input_cols,
|
861
901
|
expected_output_cols_type="float",
|
862
902
|
)
|
863
|
-
expected_output_cols = self.
|
903
|
+
expected_output_cols, _ = self._align_expected_output(
|
864
904
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
865
905
|
)
|
866
906
|
|