snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -0
- snowflake/cortex/_classify_text.py +2 -2
- snowflake/cortex/_embed_text_1024.py +37 -0
- snowflake/cortex/_embed_text_768.py +37 -0
- snowflake/cortex/_extract_answer.py +2 -2
- snowflake/cortex/_sentiment.py +2 -2
- snowflake/cortex/_summarize.py +2 -2
- snowflake/cortex/_translate.py +2 -2
- snowflake/cortex/_util.py +4 -4
- snowflake/ml/_internal/env_utils.py +5 -5
- snowflake/ml/_internal/exceptions/error_codes.py +2 -0
- snowflake/ml/_internal/telemetry.py +142 -20
- snowflake/ml/_internal/utils/db_utils.py +50 -0
- snowflake/ml/_internal/utils/identifier.py +48 -11
- snowflake/ml/_internal/utils/service_logger.py +63 -0
- snowflake/ml/_internal/utils/snowflake_env.py +23 -13
- snowflake/ml/_internal/utils/sql_identifier.py +26 -2
- snowflake/ml/_internal/utils/table_manager.py +19 -1
- snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
- snowflake/ml/data/data_connector.py +33 -7
- snowflake/ml/data/ingestor_utils.py +20 -10
- snowflake/ml/data/torch_utils.py +68 -0
- snowflake/ml/dataset/dataset.py +1 -3
- snowflake/ml/feature_store/access_manager.py +3 -3
- snowflake/ml/feature_store/feature_store.py +60 -19
- snowflake/ml/feature_store/feature_view.py +84 -30
- snowflake/ml/fileset/embedded_stage_fs.py +1 -1
- snowflake/ml/fileset/fileset.py +1 -1
- snowflake/ml/fileset/sfcfs.py +9 -3
- snowflake/ml/fileset/stage_fs.py +2 -1
- snowflake/ml/lineage/lineage_node.py +7 -2
- snowflake/ml/model/__init__.py +1 -2
- snowflake/ml/model/_client/model/model_version_impl.py +96 -12
- snowflake/ml/model/_client/ops/model_ops.py +124 -6
- snowflake/ml/model/_client/ops/service_ops.py +309 -9
- snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
- snowflake/ml/model/_client/sql/_base.py +5 -0
- snowflake/ml/model/_client/sql/model.py +1 -0
- snowflake/ml/model/_client/sql/model_version.py +9 -5
- snowflake/ml/model/_client/sql/service.py +121 -20
- snowflake/ml/model/_model_composer/model_composer.py +11 -39
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
- snowflake/ml/model/_packager/model_env/model_env.py +4 -38
- snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
- snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
- snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
- snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
- snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
- snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
- snowflake/ml/model/_packager/model_packager.py +14 -8
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
- snowflake/ml/model/_signatures/utils.py +9 -0
- snowflake/ml/model/type_hints.py +12 -145
- snowflake/ml/modeling/_internal/constants.py +1 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
- snowflake/ml/modeling/_internal/model_specifications.py +2 -0
- snowflake/ml/modeling/_internal/model_trainer.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
- snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
- snowflake/ml/modeling/cluster/birch.py +61 -21
- snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
- snowflake/ml/modeling/cluster/dbscan.py +61 -21
- snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
- snowflake/ml/modeling/cluster/k_means.py +61 -21
- snowflake/ml/modeling/cluster/mean_shift.py +61 -21
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
- snowflake/ml/modeling/cluster/optics.py +61 -21
- snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
- snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
- snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
- snowflake/ml/modeling/compose/column_transformer.py +61 -21
- snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
- snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
- snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
- snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
- snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
- snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
- snowflake/ml/modeling/covariance/oas.py +61 -21
- snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
- snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
- snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
- snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
- snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
- snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
- snowflake/ml/modeling/decomposition/pca.py +61 -21
- snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
- snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
- snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
- snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
- snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
- snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
- snowflake/ml/modeling/impute/knn_imputer.py +61 -21
- snowflake/ml/modeling/impute/missing_indicator.py +61 -21
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
- snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
- snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
- snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/lars.py +61 -21
- snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
- snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
- snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/perceptron.py +61 -21
- snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ridge.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
- snowflake/ml/modeling/manifold/isomap.py +61 -21
- snowflake/ml/modeling/manifold/mds.py +61 -21
- snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
- snowflake/ml/modeling/manifold/tsne.py +61 -21
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
- snowflake/ml/modeling/metrics/ranking.py +0 -3
- snowflake/ml/modeling/metrics/regression.py +0 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
- snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
- snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
- snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
- snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
- snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
- snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
- snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +1 -13
- snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
- snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
- snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
- snowflake/ml/modeling/svm/linear_svc.py +61 -21
- snowflake/ml/modeling/svm/linear_svr.py +61 -21
- snowflake/ml/modeling/svm/nu_svc.py +61 -21
- snowflake/ml/modeling/svm/nu_svr.py +61 -21
- snowflake/ml/modeling/svm/svc.py +61 -21
- snowflake/ml/modeling/svm/svr.py +61 -21
- snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
- snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
- snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
- snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
- snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
- snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
- snowflake/ml/monitoring/_client/model_monitor.py +126 -0
- snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
- snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
- snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
- snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
- snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
- snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
- snowflake/ml/monitoring/entities/output_score_type.py +90 -0
- snowflake/ml/registry/_manager/model_manager.py +4 -0
- snowflake/ml/registry/registry.py +166 -8
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
- snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
- snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
- snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
- snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
- snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
- snowflake/ml/_internal/utils/session_token_manager.py +0 -46
- snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
- snowflake/ml/_internal/utils/uri.py +0 -77
- snowflake/ml/data/torch_dataset.py +0 -33
- snowflake/ml/model/_api.py +0 -568
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
- snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
- snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
- snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
- snowflake/ml/model/deploy_platforms.py +0 -6
- snowflake/ml/model/models/llm.py +0 -104
- snowflake/ml/monitoring/monitor.py +0 -203
- snowflake/ml/registry/_initial_schema.py +0 -142
- snowflake/ml/registry/_schema.py +0 -82
- snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
- snowflake/ml/registry/_schema_version_manager.py +0 -163
- snowflake/ml/registry/model_registry.py +0 -2048
- snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -530,12 +527,23 @@ class MeanShift(BaseTransformer):
|
|
530
527
|
autogenerated=self._autogenerated,
|
531
528
|
subproject=_SUBPROJECT,
|
532
529
|
)
|
533
|
-
|
534
|
-
|
535
|
-
expected_output_cols_list=(
|
536
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
537
|
-
),
|
530
|
+
expected_output_cols = (
|
531
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
538
532
|
)
|
533
|
+
if isinstance(dataset, DataFrame):
|
534
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
535
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
536
|
+
)
|
537
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
538
|
+
drop_input_cols=self._drop_input_cols,
|
539
|
+
expected_output_cols_list=expected_output_cols,
|
540
|
+
example_output_pd_df=example_output_pd_df,
|
541
|
+
)
|
542
|
+
else:
|
543
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
544
|
+
drop_input_cols=self._drop_input_cols,
|
545
|
+
expected_output_cols_list=expected_output_cols,
|
546
|
+
)
|
539
547
|
self._sklearn_object = fitted_estimator
|
540
548
|
self._is_fitted = True
|
541
549
|
return output_result
|
@@ -558,6 +566,7 @@ class MeanShift(BaseTransformer):
|
|
558
566
|
"""
|
559
567
|
self._infer_input_output_cols(dataset)
|
560
568
|
super()._check_dataset_type(dataset)
|
569
|
+
|
561
570
|
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
562
571
|
estimator=self._sklearn_object,
|
563
572
|
dataset=dataset,
|
@@ -614,12 +623,41 @@ class MeanShift(BaseTransformer):
|
|
614
623
|
|
615
624
|
return rv
|
616
625
|
|
617
|
-
def
|
618
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
619
|
-
) -> List[str]:
|
626
|
+
def _align_expected_output(
|
627
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
628
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
629
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
630
|
+
and output dataframe with 1 line.
|
631
|
+
If the method is fit_predict, run 2 lines of data.
|
632
|
+
"""
|
620
633
|
# in case the inferred output column names dimension is different
|
621
634
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
622
|
-
|
635
|
+
|
636
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
637
|
+
# so change the minimum of number of rows to 2
|
638
|
+
num_examples = 2
|
639
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
640
|
+
project=_PROJECT,
|
641
|
+
subproject=_SUBPROJECT,
|
642
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
643
|
+
inspect.currentframe(), MeanShift.__class__.__name__
|
644
|
+
),
|
645
|
+
api_calls=[Session.call],
|
646
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
647
|
+
)
|
648
|
+
if output_cols_prefix == "fit_predict_":
|
649
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
650
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
651
|
+
num_examples = self._sklearn_object.n_clusters
|
652
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
653
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
654
|
+
num_examples = self._sklearn_object.min_samples
|
655
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
656
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
657
|
+
num_examples = self._sklearn_object.n_neighbors
|
658
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
659
|
+
else:
|
660
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
623
661
|
|
624
662
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
625
663
|
# seen during the fit.
|
@@ -631,12 +669,14 @@ class MeanShift(BaseTransformer):
|
|
631
669
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
632
670
|
if self.sample_weight_col:
|
633
671
|
output_df_columns_set -= set(self.sample_weight_col)
|
672
|
+
|
634
673
|
# if the dimension of inferred output column names is correct; use it
|
635
674
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
636
|
-
return expected_output_cols_list
|
675
|
+
return expected_output_cols_list, output_df_pd
|
637
676
|
# otherwise, use the sklearn estimator's output
|
638
677
|
else:
|
639
|
-
|
678
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
679
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
640
680
|
|
641
681
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
642
682
|
@telemetry.send_api_usage_telemetry(
|
@@ -682,7 +722,7 @@ class MeanShift(BaseTransformer):
|
|
682
722
|
drop_input_cols=self._drop_input_cols,
|
683
723
|
expected_output_cols_type="float",
|
684
724
|
)
|
685
|
-
expected_output_cols = self.
|
725
|
+
expected_output_cols, _ = self._align_expected_output(
|
686
726
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
687
727
|
)
|
688
728
|
|
@@ -748,7 +788,7 @@ class MeanShift(BaseTransformer):
|
|
748
788
|
drop_input_cols=self._drop_input_cols,
|
749
789
|
expected_output_cols_type="float",
|
750
790
|
)
|
751
|
-
expected_output_cols = self.
|
791
|
+
expected_output_cols, _ = self._align_expected_output(
|
752
792
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
753
793
|
)
|
754
794
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -811,7 +851,7 @@ class MeanShift(BaseTransformer):
|
|
811
851
|
drop_input_cols=self._drop_input_cols,
|
812
852
|
expected_output_cols_type="float",
|
813
853
|
)
|
814
|
-
expected_output_cols = self.
|
854
|
+
expected_output_cols, _ = self._align_expected_output(
|
815
855
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
816
856
|
)
|
817
857
|
|
@@ -876,7 +916,7 @@ class MeanShift(BaseTransformer):
|
|
876
916
|
drop_input_cols = self._drop_input_cols,
|
877
917
|
expected_output_cols_type="float",
|
878
918
|
)
|
879
|
-
expected_output_cols = self.
|
919
|
+
expected_output_cols, _ = self._align_expected_output(
|
880
920
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
881
921
|
)
|
882
922
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -582,12 +579,23 @@ class MiniBatchKMeans(BaseTransformer):
|
|
582
579
|
autogenerated=self._autogenerated,
|
583
580
|
subproject=_SUBPROJECT,
|
584
581
|
)
|
585
|
-
|
586
|
-
|
587
|
-
expected_output_cols_list=(
|
588
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
589
|
-
),
|
582
|
+
expected_output_cols = (
|
583
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
590
584
|
)
|
585
|
+
if isinstance(dataset, DataFrame):
|
586
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
587
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
588
|
+
)
|
589
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
590
|
+
drop_input_cols=self._drop_input_cols,
|
591
|
+
expected_output_cols_list=expected_output_cols,
|
592
|
+
example_output_pd_df=example_output_pd_df,
|
593
|
+
)
|
594
|
+
else:
|
595
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
596
|
+
drop_input_cols=self._drop_input_cols,
|
597
|
+
expected_output_cols_list=expected_output_cols,
|
598
|
+
)
|
591
599
|
self._sklearn_object = fitted_estimator
|
592
600
|
self._is_fitted = True
|
593
601
|
return output_result
|
@@ -612,6 +620,7 @@ class MiniBatchKMeans(BaseTransformer):
|
|
612
620
|
"""
|
613
621
|
self._infer_input_output_cols(dataset)
|
614
622
|
super()._check_dataset_type(dataset)
|
623
|
+
|
615
624
|
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
616
625
|
estimator=self._sklearn_object,
|
617
626
|
dataset=dataset,
|
@@ -668,12 +677,41 @@ class MiniBatchKMeans(BaseTransformer):
|
|
668
677
|
|
669
678
|
return rv
|
670
679
|
|
671
|
-
def
|
672
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
673
|
-
) -> List[str]:
|
680
|
+
def _align_expected_output(
|
681
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
682
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
683
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
684
|
+
and output dataframe with 1 line.
|
685
|
+
If the method is fit_predict, run 2 lines of data.
|
686
|
+
"""
|
674
687
|
# in case the inferred output column names dimension is different
|
675
688
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
676
|
-
|
689
|
+
|
690
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
691
|
+
# so change the minimum of number of rows to 2
|
692
|
+
num_examples = 2
|
693
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
694
|
+
project=_PROJECT,
|
695
|
+
subproject=_SUBPROJECT,
|
696
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
697
|
+
inspect.currentframe(), MiniBatchKMeans.__class__.__name__
|
698
|
+
),
|
699
|
+
api_calls=[Session.call],
|
700
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
701
|
+
)
|
702
|
+
if output_cols_prefix == "fit_predict_":
|
703
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
704
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
705
|
+
num_examples = self._sklearn_object.n_clusters
|
706
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
707
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
708
|
+
num_examples = self._sklearn_object.min_samples
|
709
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
710
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
711
|
+
num_examples = self._sklearn_object.n_neighbors
|
712
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
713
|
+
else:
|
714
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
677
715
|
|
678
716
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
679
717
|
# seen during the fit.
|
@@ -685,12 +723,14 @@ class MiniBatchKMeans(BaseTransformer):
|
|
685
723
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
686
724
|
if self.sample_weight_col:
|
687
725
|
output_df_columns_set -= set(self.sample_weight_col)
|
726
|
+
|
688
727
|
# if the dimension of inferred output column names is correct; use it
|
689
728
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
690
|
-
return expected_output_cols_list
|
729
|
+
return expected_output_cols_list, output_df_pd
|
691
730
|
# otherwise, use the sklearn estimator's output
|
692
731
|
else:
|
693
|
-
|
732
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
733
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
694
734
|
|
695
735
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
696
736
|
@telemetry.send_api_usage_telemetry(
|
@@ -736,7 +776,7 @@ class MiniBatchKMeans(BaseTransformer):
|
|
736
776
|
drop_input_cols=self._drop_input_cols,
|
737
777
|
expected_output_cols_type="float",
|
738
778
|
)
|
739
|
-
expected_output_cols = self.
|
779
|
+
expected_output_cols, _ = self._align_expected_output(
|
740
780
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
741
781
|
)
|
742
782
|
|
@@ -802,7 +842,7 @@ class MiniBatchKMeans(BaseTransformer):
|
|
802
842
|
drop_input_cols=self._drop_input_cols,
|
803
843
|
expected_output_cols_type="float",
|
804
844
|
)
|
805
|
-
expected_output_cols = self.
|
845
|
+
expected_output_cols, _ = self._align_expected_output(
|
806
846
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
807
847
|
)
|
808
848
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -865,7 +905,7 @@ class MiniBatchKMeans(BaseTransformer):
|
|
865
905
|
drop_input_cols=self._drop_input_cols,
|
866
906
|
expected_output_cols_type="float",
|
867
907
|
)
|
868
|
-
expected_output_cols = self.
|
908
|
+
expected_output_cols, _ = self._align_expected_output(
|
869
909
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
870
910
|
)
|
871
911
|
|
@@ -930,7 +970,7 @@ class MiniBatchKMeans(BaseTransformer):
|
|
930
970
|
drop_input_cols = self._drop_input_cols,
|
931
971
|
expected_output_cols_type="float",
|
932
972
|
)
|
933
|
-
expected_output_cols = self.
|
973
|
+
expected_output_cols, _ = self._align_expected_output(
|
934
974
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
935
975
|
)
|
936
976
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -598,12 +595,23 @@ class OPTICS(BaseTransformer):
|
|
598
595
|
autogenerated=self._autogenerated,
|
599
596
|
subproject=_SUBPROJECT,
|
600
597
|
)
|
601
|
-
|
602
|
-
|
603
|
-
expected_output_cols_list=(
|
604
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
605
|
-
),
|
598
|
+
expected_output_cols = (
|
599
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
606
600
|
)
|
601
|
+
if isinstance(dataset, DataFrame):
|
602
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
603
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
604
|
+
)
|
605
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
606
|
+
drop_input_cols=self._drop_input_cols,
|
607
|
+
expected_output_cols_list=expected_output_cols,
|
608
|
+
example_output_pd_df=example_output_pd_df,
|
609
|
+
)
|
610
|
+
else:
|
611
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
612
|
+
drop_input_cols=self._drop_input_cols,
|
613
|
+
expected_output_cols_list=expected_output_cols,
|
614
|
+
)
|
607
615
|
self._sklearn_object = fitted_estimator
|
608
616
|
self._is_fitted = True
|
609
617
|
return output_result
|
@@ -626,6 +634,7 @@ class OPTICS(BaseTransformer):
|
|
626
634
|
"""
|
627
635
|
self._infer_input_output_cols(dataset)
|
628
636
|
super()._check_dataset_type(dataset)
|
637
|
+
|
629
638
|
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
630
639
|
estimator=self._sklearn_object,
|
631
640
|
dataset=dataset,
|
@@ -682,12 +691,41 @@ class OPTICS(BaseTransformer):
|
|
682
691
|
|
683
692
|
return rv
|
684
693
|
|
685
|
-
def
|
686
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
687
|
-
) -> List[str]:
|
694
|
+
def _align_expected_output(
|
695
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
696
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
697
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
698
|
+
and output dataframe with 1 line.
|
699
|
+
If the method is fit_predict, run 2 lines of data.
|
700
|
+
"""
|
688
701
|
# in case the inferred output column names dimension is different
|
689
702
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
690
|
-
|
703
|
+
|
704
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
705
|
+
# so change the minimum of number of rows to 2
|
706
|
+
num_examples = 2
|
707
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
708
|
+
project=_PROJECT,
|
709
|
+
subproject=_SUBPROJECT,
|
710
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
711
|
+
inspect.currentframe(), OPTICS.__class__.__name__
|
712
|
+
),
|
713
|
+
api_calls=[Session.call],
|
714
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
715
|
+
)
|
716
|
+
if output_cols_prefix == "fit_predict_":
|
717
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
718
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
719
|
+
num_examples = self._sklearn_object.n_clusters
|
720
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
721
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
722
|
+
num_examples = self._sklearn_object.min_samples
|
723
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
724
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
725
|
+
num_examples = self._sklearn_object.n_neighbors
|
726
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
727
|
+
else:
|
728
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
691
729
|
|
692
730
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
693
731
|
# seen during the fit.
|
@@ -699,12 +737,14 @@ class OPTICS(BaseTransformer):
|
|
699
737
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
700
738
|
if self.sample_weight_col:
|
701
739
|
output_df_columns_set -= set(self.sample_weight_col)
|
740
|
+
|
702
741
|
# if the dimension of inferred output column names is correct; use it
|
703
742
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
704
|
-
return expected_output_cols_list
|
743
|
+
return expected_output_cols_list, output_df_pd
|
705
744
|
# otherwise, use the sklearn estimator's output
|
706
745
|
else:
|
707
|
-
|
746
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
747
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
708
748
|
|
709
749
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
710
750
|
@telemetry.send_api_usage_telemetry(
|
@@ -750,7 +790,7 @@ class OPTICS(BaseTransformer):
|
|
750
790
|
drop_input_cols=self._drop_input_cols,
|
751
791
|
expected_output_cols_type="float",
|
752
792
|
)
|
753
|
-
expected_output_cols = self.
|
793
|
+
expected_output_cols, _ = self._align_expected_output(
|
754
794
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
755
795
|
)
|
756
796
|
|
@@ -816,7 +856,7 @@ class OPTICS(BaseTransformer):
|
|
816
856
|
drop_input_cols=self._drop_input_cols,
|
817
857
|
expected_output_cols_type="float",
|
818
858
|
)
|
819
|
-
expected_output_cols = self.
|
859
|
+
expected_output_cols, _ = self._align_expected_output(
|
820
860
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
821
861
|
)
|
822
862
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -879,7 +919,7 @@ class OPTICS(BaseTransformer):
|
|
879
919
|
drop_input_cols=self._drop_input_cols,
|
880
920
|
expected_output_cols_type="float",
|
881
921
|
)
|
882
|
-
expected_output_cols = self.
|
922
|
+
expected_output_cols, _ = self._align_expected_output(
|
883
923
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
884
924
|
)
|
885
925
|
|
@@ -944,7 +984,7 @@ class OPTICS(BaseTransformer):
|
|
944
984
|
drop_input_cols = self._drop_input_cols,
|
945
985
|
expected_output_cols_type="float",
|
946
986
|
)
|
947
|
-
expected_output_cols = self.
|
987
|
+
expected_output_cols, _ = self._align_expected_output(
|
948
988
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
949
989
|
)
|
950
990
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -534,12 +531,23 @@ class SpectralBiclustering(BaseTransformer):
|
|
534
531
|
autogenerated=self._autogenerated,
|
535
532
|
subproject=_SUBPROJECT,
|
536
533
|
)
|
537
|
-
|
538
|
-
|
539
|
-
expected_output_cols_list=(
|
540
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
541
|
-
),
|
534
|
+
expected_output_cols = (
|
535
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
542
536
|
)
|
537
|
+
if isinstance(dataset, DataFrame):
|
538
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
539
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
540
|
+
)
|
541
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
542
|
+
drop_input_cols=self._drop_input_cols,
|
543
|
+
expected_output_cols_list=expected_output_cols,
|
544
|
+
example_output_pd_df=example_output_pd_df,
|
545
|
+
)
|
546
|
+
else:
|
547
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
548
|
+
drop_input_cols=self._drop_input_cols,
|
549
|
+
expected_output_cols_list=expected_output_cols,
|
550
|
+
)
|
543
551
|
self._sklearn_object = fitted_estimator
|
544
552
|
self._is_fitted = True
|
545
553
|
return output_result
|
@@ -562,6 +570,7 @@ class SpectralBiclustering(BaseTransformer):
|
|
562
570
|
"""
|
563
571
|
self._infer_input_output_cols(dataset)
|
564
572
|
super()._check_dataset_type(dataset)
|
573
|
+
|
565
574
|
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
566
575
|
estimator=self._sklearn_object,
|
567
576
|
dataset=dataset,
|
@@ -618,12 +627,41 @@ class SpectralBiclustering(BaseTransformer):
|
|
618
627
|
|
619
628
|
return rv
|
620
629
|
|
621
|
-
def
|
622
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
623
|
-
) -> List[str]:
|
630
|
+
def _align_expected_output(
|
631
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
632
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
633
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
634
|
+
and output dataframe with 1 line.
|
635
|
+
If the method is fit_predict, run 2 lines of data.
|
636
|
+
"""
|
624
637
|
# in case the inferred output column names dimension is different
|
625
638
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
626
|
-
|
639
|
+
|
640
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
641
|
+
# so change the minimum of number of rows to 2
|
642
|
+
num_examples = 2
|
643
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
644
|
+
project=_PROJECT,
|
645
|
+
subproject=_SUBPROJECT,
|
646
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
647
|
+
inspect.currentframe(), SpectralBiclustering.__class__.__name__
|
648
|
+
),
|
649
|
+
api_calls=[Session.call],
|
650
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
651
|
+
)
|
652
|
+
if output_cols_prefix == "fit_predict_":
|
653
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
654
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
655
|
+
num_examples = self._sklearn_object.n_clusters
|
656
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
657
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
658
|
+
num_examples = self._sklearn_object.min_samples
|
659
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
660
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
661
|
+
num_examples = self._sklearn_object.n_neighbors
|
662
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
663
|
+
else:
|
664
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
627
665
|
|
628
666
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
629
667
|
# seen during the fit.
|
@@ -635,12 +673,14 @@ class SpectralBiclustering(BaseTransformer):
|
|
635
673
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
636
674
|
if self.sample_weight_col:
|
637
675
|
output_df_columns_set -= set(self.sample_weight_col)
|
676
|
+
|
638
677
|
# if the dimension of inferred output column names is correct; use it
|
639
678
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
640
|
-
return expected_output_cols_list
|
679
|
+
return expected_output_cols_list, output_df_pd
|
641
680
|
# otherwise, use the sklearn estimator's output
|
642
681
|
else:
|
643
|
-
|
682
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
683
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
644
684
|
|
645
685
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
646
686
|
@telemetry.send_api_usage_telemetry(
|
@@ -686,7 +726,7 @@ class SpectralBiclustering(BaseTransformer):
|
|
686
726
|
drop_input_cols=self._drop_input_cols,
|
687
727
|
expected_output_cols_type="float",
|
688
728
|
)
|
689
|
-
expected_output_cols = self.
|
729
|
+
expected_output_cols, _ = self._align_expected_output(
|
690
730
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
691
731
|
)
|
692
732
|
|
@@ -752,7 +792,7 @@ class SpectralBiclustering(BaseTransformer):
|
|
752
792
|
drop_input_cols=self._drop_input_cols,
|
753
793
|
expected_output_cols_type="float",
|
754
794
|
)
|
755
|
-
expected_output_cols = self.
|
795
|
+
expected_output_cols, _ = self._align_expected_output(
|
756
796
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
757
797
|
)
|
758
798
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -815,7 +855,7 @@ class SpectralBiclustering(BaseTransformer):
|
|
815
855
|
drop_input_cols=self._drop_input_cols,
|
816
856
|
expected_output_cols_type="float",
|
817
857
|
)
|
818
|
-
expected_output_cols = self.
|
858
|
+
expected_output_cols, _ = self._align_expected_output(
|
819
859
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
820
860
|
)
|
821
861
|
|
@@ -880,7 +920,7 @@ class SpectralBiclustering(BaseTransformer):
|
|
880
920
|
drop_input_cols = self._drop_input_cols,
|
881
921
|
expected_output_cols_type="float",
|
882
922
|
)
|
883
|
-
expected_output_cols = self.
|
923
|
+
expected_output_cols, _ = self._align_expected_output(
|
884
924
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
885
925
|
)
|
886
926
|
|