snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +7 -33
- snowflake/ml/_internal/env_utils.py +11 -5
- snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
- snowflake/ml/_internal/telemetry.py +156 -20
- snowflake/ml/_internal/utils/identifier.py +48 -11
- snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
- snowflake/ml/_internal/utils/snowflake_env.py +23 -13
- snowflake/ml/_internal/utils/sql_identifier.py +1 -1
- snowflake/ml/_internal/utils/table_manager.py +19 -1
- snowflake/ml/_internal/utils/uri.py +2 -2
- snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
- snowflake/ml/data/data_connector.py +88 -9
- snowflake/ml/data/data_ingestor.py +18 -1
- snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
- snowflake/ml/data/torch_utils.py +68 -0
- snowflake/ml/dataset/dataset.py +1 -3
- snowflake/ml/dataset/dataset_metadata.py +3 -1
- snowflake/ml/dataset/dataset_reader.py +9 -3
- snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
- snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
- snowflake/ml/feature_store/examples/example_helper.py +69 -31
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
- snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
- snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
- snowflake/ml/feature_store/feature_store.py +100 -41
- snowflake/ml/feature_store/feature_view.py +149 -5
- snowflake/ml/fileset/embedded_stage_fs.py +1 -1
- snowflake/ml/fileset/fileset.py +1 -1
- snowflake/ml/fileset/sfcfs.py +9 -3
- snowflake/ml/model/_client/model/model_impl.py +11 -2
- snowflake/ml/model/_client/model/model_version_impl.py +186 -20
- snowflake/ml/model/_client/ops/model_ops.py +144 -30
- snowflake/ml/model/_client/ops/service_ops.py +312 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
- snowflake/ml/model/_client/sql/model_version.py +13 -4
- snowflake/ml/model/_client/sql/service.py +196 -0
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
- snowflake/ml/model/_model_composer/model_composer.py +5 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
- snowflake/ml/model/_packager/model_env/model_env.py +7 -2
- snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
- snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
- snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
- snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
- snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
- snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
- snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
- snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
- snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
- snowflake/ml/model/_packager/model_packager.py +4 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/_signatures/utils.py +9 -0
- snowflake/ml/model/models/llm.py +3 -1
- snowflake/ml/model/type_hints.py +10 -4
- snowflake/ml/modeling/_internal/constants.py +1 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
- snowflake/ml/modeling/_internal/model_specifications.py +2 -0
- snowflake/ml/modeling/_internal/model_trainer.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
- snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
- snowflake/ml/modeling/cluster/birch.py +60 -21
- snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
- snowflake/ml/modeling/cluster/dbscan.py +60 -21
- snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
- snowflake/ml/modeling/cluster/k_means.py +60 -21
- snowflake/ml/modeling/cluster/mean_shift.py +60 -21
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
- snowflake/ml/modeling/cluster/optics.py +60 -21
- snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
- snowflake/ml/modeling/compose/column_transformer.py +60 -21
- snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
- snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
- snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
- snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
- snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
- snowflake/ml/modeling/covariance/oas.py +60 -21
- snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
- snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
- snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
- snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
- snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/pca.py +60 -21
- snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
- snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
- snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
- snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
- snowflake/ml/modeling/framework/base.py +28 -19
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
- snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
- snowflake/ml/modeling/impute/knn_imputer.py +60 -21
- snowflake/ml/modeling/impute/missing_indicator.py +60 -21
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/lars.py +60 -21
- snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
- snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/perceptron.py +60 -21
- snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ridge.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
- snowflake/ml/modeling/manifold/isomap.py +60 -21
- snowflake/ml/modeling/manifold/mds.py +60 -21
- snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
- snowflake/ml/modeling/manifold/tsne.py +60 -21
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
- snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
- snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +4 -12
- snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
- snowflake/ml/modeling/svm/linear_svc.py +60 -21
- snowflake/ml/modeling/svm/linear_svr.py +60 -21
- snowflake/ml/modeling/svm/nu_svc.py +60 -21
- snowflake/ml/modeling/svm/nu_svr.py +60 -21
- snowflake/ml/modeling/svm/svc.py +60 -21
- snowflake/ml/modeling/svm/svr.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
- snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
- snowflake/ml/registry/_manager/model_manager.py +20 -2
- snowflake/ml/registry/model_registry.py +1 -1
- snowflake/ml/registry/registry.py +1 -2
- snowflake/ml/utils/sql_client.py +22 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -530,12 +527,23 @@ class MeanShift(BaseTransformer):
|
|
530
527
|
autogenerated=self._autogenerated,
|
531
528
|
subproject=_SUBPROJECT,
|
532
529
|
)
|
533
|
-
|
534
|
-
|
535
|
-
expected_output_cols_list=(
|
536
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
537
|
-
),
|
530
|
+
expected_output_cols = (
|
531
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
538
532
|
)
|
533
|
+
if isinstance(dataset, DataFrame):
|
534
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
535
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
536
|
+
)
|
537
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
538
|
+
drop_input_cols=self._drop_input_cols,
|
539
|
+
expected_output_cols_list=expected_output_cols,
|
540
|
+
example_output_pd_df=example_output_pd_df,
|
541
|
+
)
|
542
|
+
else:
|
543
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
544
|
+
drop_input_cols=self._drop_input_cols,
|
545
|
+
expected_output_cols_list=expected_output_cols,
|
546
|
+
)
|
539
547
|
self._sklearn_object = fitted_estimator
|
540
548
|
self._is_fitted = True
|
541
549
|
return output_result
|
@@ -614,12 +622,41 @@ class MeanShift(BaseTransformer):
|
|
614
622
|
|
615
623
|
return rv
|
616
624
|
|
617
|
-
def
|
618
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
619
|
-
) -> List[str]:
|
625
|
+
def _align_expected_output(
|
626
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
627
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
628
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
629
|
+
and output dataframe with 1 line.
|
630
|
+
If the method is fit_predict, run 2 lines of data.
|
631
|
+
"""
|
620
632
|
# in case the inferred output column names dimension is different
|
621
633
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
622
|
-
|
634
|
+
|
635
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
636
|
+
# so change the minimum of number of rows to 2
|
637
|
+
num_examples = 2
|
638
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
639
|
+
project=_PROJECT,
|
640
|
+
subproject=_SUBPROJECT,
|
641
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
642
|
+
inspect.currentframe(), MeanShift.__class__.__name__
|
643
|
+
),
|
644
|
+
api_calls=[Session.call],
|
645
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
646
|
+
)
|
647
|
+
if output_cols_prefix == "fit_predict_":
|
648
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
649
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
650
|
+
num_examples = self._sklearn_object.n_clusters
|
651
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
652
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
653
|
+
num_examples = self._sklearn_object.min_samples
|
654
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
655
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
656
|
+
num_examples = self._sklearn_object.n_neighbors
|
657
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
658
|
+
else:
|
659
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
623
660
|
|
624
661
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
625
662
|
# seen during the fit.
|
@@ -631,12 +668,14 @@ class MeanShift(BaseTransformer):
|
|
631
668
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
632
669
|
if self.sample_weight_col:
|
633
670
|
output_df_columns_set -= set(self.sample_weight_col)
|
671
|
+
|
634
672
|
# if the dimension of inferred output column names is correct; use it
|
635
673
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
636
|
-
return expected_output_cols_list
|
674
|
+
return expected_output_cols_list, output_df_pd
|
637
675
|
# otherwise, use the sklearn estimator's output
|
638
676
|
else:
|
639
|
-
|
677
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
678
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
640
679
|
|
641
680
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
642
681
|
@telemetry.send_api_usage_telemetry(
|
@@ -682,7 +721,7 @@ class MeanShift(BaseTransformer):
|
|
682
721
|
drop_input_cols=self._drop_input_cols,
|
683
722
|
expected_output_cols_type="float",
|
684
723
|
)
|
685
|
-
expected_output_cols = self.
|
724
|
+
expected_output_cols, _ = self._align_expected_output(
|
686
725
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
687
726
|
)
|
688
727
|
|
@@ -748,7 +787,7 @@ class MeanShift(BaseTransformer):
|
|
748
787
|
drop_input_cols=self._drop_input_cols,
|
749
788
|
expected_output_cols_type="float",
|
750
789
|
)
|
751
|
-
expected_output_cols = self.
|
790
|
+
expected_output_cols, _ = self._align_expected_output(
|
752
791
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
753
792
|
)
|
754
793
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -811,7 +850,7 @@ class MeanShift(BaseTransformer):
|
|
811
850
|
drop_input_cols=self._drop_input_cols,
|
812
851
|
expected_output_cols_type="float",
|
813
852
|
)
|
814
|
-
expected_output_cols = self.
|
853
|
+
expected_output_cols, _ = self._align_expected_output(
|
815
854
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
816
855
|
)
|
817
856
|
|
@@ -876,7 +915,7 @@ class MeanShift(BaseTransformer):
|
|
876
915
|
drop_input_cols = self._drop_input_cols,
|
877
916
|
expected_output_cols_type="float",
|
878
917
|
)
|
879
|
-
expected_output_cols = self.
|
918
|
+
expected_output_cols, _ = self._align_expected_output(
|
880
919
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
881
920
|
)
|
882
921
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -582,12 +579,23 @@ class MiniBatchKMeans(BaseTransformer):
|
|
582
579
|
autogenerated=self._autogenerated,
|
583
580
|
subproject=_SUBPROJECT,
|
584
581
|
)
|
585
|
-
|
586
|
-
|
587
|
-
expected_output_cols_list=(
|
588
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
589
|
-
),
|
582
|
+
expected_output_cols = (
|
583
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
590
584
|
)
|
585
|
+
if isinstance(dataset, DataFrame):
|
586
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
587
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
588
|
+
)
|
589
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
590
|
+
drop_input_cols=self._drop_input_cols,
|
591
|
+
expected_output_cols_list=expected_output_cols,
|
592
|
+
example_output_pd_df=example_output_pd_df,
|
593
|
+
)
|
594
|
+
else:
|
595
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
596
|
+
drop_input_cols=self._drop_input_cols,
|
597
|
+
expected_output_cols_list=expected_output_cols,
|
598
|
+
)
|
591
599
|
self._sklearn_object = fitted_estimator
|
592
600
|
self._is_fitted = True
|
593
601
|
return output_result
|
@@ -668,12 +676,41 @@ class MiniBatchKMeans(BaseTransformer):
|
|
668
676
|
|
669
677
|
return rv
|
670
678
|
|
671
|
-
def
|
672
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
673
|
-
) -> List[str]:
|
679
|
+
def _align_expected_output(
|
680
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
681
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
682
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
683
|
+
and output dataframe with 1 line.
|
684
|
+
If the method is fit_predict, run 2 lines of data.
|
685
|
+
"""
|
674
686
|
# in case the inferred output column names dimension is different
|
675
687
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
676
|
-
|
688
|
+
|
689
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
690
|
+
# so change the minimum of number of rows to 2
|
691
|
+
num_examples = 2
|
692
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
693
|
+
project=_PROJECT,
|
694
|
+
subproject=_SUBPROJECT,
|
695
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
696
|
+
inspect.currentframe(), MiniBatchKMeans.__class__.__name__
|
697
|
+
),
|
698
|
+
api_calls=[Session.call],
|
699
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
700
|
+
)
|
701
|
+
if output_cols_prefix == "fit_predict_":
|
702
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
703
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
704
|
+
num_examples = self._sklearn_object.n_clusters
|
705
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
706
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
707
|
+
num_examples = self._sklearn_object.min_samples
|
708
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
709
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
710
|
+
num_examples = self._sklearn_object.n_neighbors
|
711
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
712
|
+
else:
|
713
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
677
714
|
|
678
715
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
679
716
|
# seen during the fit.
|
@@ -685,12 +722,14 @@ class MiniBatchKMeans(BaseTransformer):
|
|
685
722
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
686
723
|
if self.sample_weight_col:
|
687
724
|
output_df_columns_set -= set(self.sample_weight_col)
|
725
|
+
|
688
726
|
# if the dimension of inferred output column names is correct; use it
|
689
727
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
690
|
-
return expected_output_cols_list
|
728
|
+
return expected_output_cols_list, output_df_pd
|
691
729
|
# otherwise, use the sklearn estimator's output
|
692
730
|
else:
|
693
|
-
|
731
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
732
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
694
733
|
|
695
734
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
696
735
|
@telemetry.send_api_usage_telemetry(
|
@@ -736,7 +775,7 @@ class MiniBatchKMeans(BaseTransformer):
|
|
736
775
|
drop_input_cols=self._drop_input_cols,
|
737
776
|
expected_output_cols_type="float",
|
738
777
|
)
|
739
|
-
expected_output_cols = self.
|
778
|
+
expected_output_cols, _ = self._align_expected_output(
|
740
779
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
741
780
|
)
|
742
781
|
|
@@ -802,7 +841,7 @@ class MiniBatchKMeans(BaseTransformer):
|
|
802
841
|
drop_input_cols=self._drop_input_cols,
|
803
842
|
expected_output_cols_type="float",
|
804
843
|
)
|
805
|
-
expected_output_cols = self.
|
844
|
+
expected_output_cols, _ = self._align_expected_output(
|
806
845
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
807
846
|
)
|
808
847
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -865,7 +904,7 @@ class MiniBatchKMeans(BaseTransformer):
|
|
865
904
|
drop_input_cols=self._drop_input_cols,
|
866
905
|
expected_output_cols_type="float",
|
867
906
|
)
|
868
|
-
expected_output_cols = self.
|
907
|
+
expected_output_cols, _ = self._align_expected_output(
|
869
908
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
870
909
|
)
|
871
910
|
|
@@ -930,7 +969,7 @@ class MiniBatchKMeans(BaseTransformer):
|
|
930
969
|
drop_input_cols = self._drop_input_cols,
|
931
970
|
expected_output_cols_type="float",
|
932
971
|
)
|
933
|
-
expected_output_cols = self.
|
972
|
+
expected_output_cols, _ = self._align_expected_output(
|
934
973
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
935
974
|
)
|
936
975
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -598,12 +595,23 @@ class OPTICS(BaseTransformer):
|
|
598
595
|
autogenerated=self._autogenerated,
|
599
596
|
subproject=_SUBPROJECT,
|
600
597
|
)
|
601
|
-
|
602
|
-
|
603
|
-
expected_output_cols_list=(
|
604
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
605
|
-
),
|
598
|
+
expected_output_cols = (
|
599
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
606
600
|
)
|
601
|
+
if isinstance(dataset, DataFrame):
|
602
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
603
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
604
|
+
)
|
605
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
606
|
+
drop_input_cols=self._drop_input_cols,
|
607
|
+
expected_output_cols_list=expected_output_cols,
|
608
|
+
example_output_pd_df=example_output_pd_df,
|
609
|
+
)
|
610
|
+
else:
|
611
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
612
|
+
drop_input_cols=self._drop_input_cols,
|
613
|
+
expected_output_cols_list=expected_output_cols,
|
614
|
+
)
|
607
615
|
self._sklearn_object = fitted_estimator
|
608
616
|
self._is_fitted = True
|
609
617
|
return output_result
|
@@ -682,12 +690,41 @@ class OPTICS(BaseTransformer):
|
|
682
690
|
|
683
691
|
return rv
|
684
692
|
|
685
|
-
def
|
686
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
687
|
-
) -> List[str]:
|
693
|
+
def _align_expected_output(
|
694
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
695
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
696
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
697
|
+
and output dataframe with 1 line.
|
698
|
+
If the method is fit_predict, run 2 lines of data.
|
699
|
+
"""
|
688
700
|
# in case the inferred output column names dimension is different
|
689
701
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
690
|
-
|
702
|
+
|
703
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
704
|
+
# so change the minimum of number of rows to 2
|
705
|
+
num_examples = 2
|
706
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
707
|
+
project=_PROJECT,
|
708
|
+
subproject=_SUBPROJECT,
|
709
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
710
|
+
inspect.currentframe(), OPTICS.__class__.__name__
|
711
|
+
),
|
712
|
+
api_calls=[Session.call],
|
713
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
714
|
+
)
|
715
|
+
if output_cols_prefix == "fit_predict_":
|
716
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
717
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
718
|
+
num_examples = self._sklearn_object.n_clusters
|
719
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
720
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
721
|
+
num_examples = self._sklearn_object.min_samples
|
722
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
723
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
724
|
+
num_examples = self._sklearn_object.n_neighbors
|
725
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
726
|
+
else:
|
727
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
691
728
|
|
692
729
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
693
730
|
# seen during the fit.
|
@@ -699,12 +736,14 @@ class OPTICS(BaseTransformer):
|
|
699
736
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
700
737
|
if self.sample_weight_col:
|
701
738
|
output_df_columns_set -= set(self.sample_weight_col)
|
739
|
+
|
702
740
|
# if the dimension of inferred output column names is correct; use it
|
703
741
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
704
|
-
return expected_output_cols_list
|
742
|
+
return expected_output_cols_list, output_df_pd
|
705
743
|
# otherwise, use the sklearn estimator's output
|
706
744
|
else:
|
707
|
-
|
745
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
746
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
708
747
|
|
709
748
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
710
749
|
@telemetry.send_api_usage_telemetry(
|
@@ -750,7 +789,7 @@ class OPTICS(BaseTransformer):
|
|
750
789
|
drop_input_cols=self._drop_input_cols,
|
751
790
|
expected_output_cols_type="float",
|
752
791
|
)
|
753
|
-
expected_output_cols = self.
|
792
|
+
expected_output_cols, _ = self._align_expected_output(
|
754
793
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
755
794
|
)
|
756
795
|
|
@@ -816,7 +855,7 @@ class OPTICS(BaseTransformer):
|
|
816
855
|
drop_input_cols=self._drop_input_cols,
|
817
856
|
expected_output_cols_type="float",
|
818
857
|
)
|
819
|
-
expected_output_cols = self.
|
858
|
+
expected_output_cols, _ = self._align_expected_output(
|
820
859
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
821
860
|
)
|
822
861
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -879,7 +918,7 @@ class OPTICS(BaseTransformer):
|
|
879
918
|
drop_input_cols=self._drop_input_cols,
|
880
919
|
expected_output_cols_type="float",
|
881
920
|
)
|
882
|
-
expected_output_cols = self.
|
921
|
+
expected_output_cols, _ = self._align_expected_output(
|
883
922
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
884
923
|
)
|
885
924
|
|
@@ -944,7 +983,7 @@ class OPTICS(BaseTransformer):
|
|
944
983
|
drop_input_cols = self._drop_input_cols,
|
945
984
|
expected_output_cols_type="float",
|
946
985
|
)
|
947
|
-
expected_output_cols = self.
|
986
|
+
expected_output_cols, _ = self._align_expected_output(
|
948
987
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
949
988
|
)
|
950
989
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -534,12 +531,23 @@ class SpectralBiclustering(BaseTransformer):
|
|
534
531
|
autogenerated=self._autogenerated,
|
535
532
|
subproject=_SUBPROJECT,
|
536
533
|
)
|
537
|
-
|
538
|
-
|
539
|
-
expected_output_cols_list=(
|
540
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
541
|
-
),
|
534
|
+
expected_output_cols = (
|
535
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
542
536
|
)
|
537
|
+
if isinstance(dataset, DataFrame):
|
538
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
539
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
540
|
+
)
|
541
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
542
|
+
drop_input_cols=self._drop_input_cols,
|
543
|
+
expected_output_cols_list=expected_output_cols,
|
544
|
+
example_output_pd_df=example_output_pd_df,
|
545
|
+
)
|
546
|
+
else:
|
547
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
548
|
+
drop_input_cols=self._drop_input_cols,
|
549
|
+
expected_output_cols_list=expected_output_cols,
|
550
|
+
)
|
543
551
|
self._sklearn_object = fitted_estimator
|
544
552
|
self._is_fitted = True
|
545
553
|
return output_result
|
@@ -618,12 +626,41 @@ class SpectralBiclustering(BaseTransformer):
|
|
618
626
|
|
619
627
|
return rv
|
620
628
|
|
621
|
-
def
|
622
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
623
|
-
) -> List[str]:
|
629
|
+
def _align_expected_output(
|
630
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
631
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
632
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
633
|
+
and output dataframe with 1 line.
|
634
|
+
If the method is fit_predict, run 2 lines of data.
|
635
|
+
"""
|
624
636
|
# in case the inferred output column names dimension is different
|
625
637
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
626
|
-
|
638
|
+
|
639
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
640
|
+
# so change the minimum of number of rows to 2
|
641
|
+
num_examples = 2
|
642
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
643
|
+
project=_PROJECT,
|
644
|
+
subproject=_SUBPROJECT,
|
645
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
646
|
+
inspect.currentframe(), SpectralBiclustering.__class__.__name__
|
647
|
+
),
|
648
|
+
api_calls=[Session.call],
|
649
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
650
|
+
)
|
651
|
+
if output_cols_prefix == "fit_predict_":
|
652
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
653
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
654
|
+
num_examples = self._sklearn_object.n_clusters
|
655
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
656
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
657
|
+
num_examples = self._sklearn_object.min_samples
|
658
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
659
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
660
|
+
num_examples = self._sklearn_object.n_neighbors
|
661
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
662
|
+
else:
|
663
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
627
664
|
|
628
665
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
629
666
|
# seen during the fit.
|
@@ -635,12 +672,14 @@ class SpectralBiclustering(BaseTransformer):
|
|
635
672
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
636
673
|
if self.sample_weight_col:
|
637
674
|
output_df_columns_set -= set(self.sample_weight_col)
|
675
|
+
|
638
676
|
# if the dimension of inferred output column names is correct; use it
|
639
677
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
640
|
-
return expected_output_cols_list
|
678
|
+
return expected_output_cols_list, output_df_pd
|
641
679
|
# otherwise, use the sklearn estimator's output
|
642
680
|
else:
|
643
|
-
|
681
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
682
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
644
683
|
|
645
684
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
646
685
|
@telemetry.send_api_usage_telemetry(
|
@@ -686,7 +725,7 @@ class SpectralBiclustering(BaseTransformer):
|
|
686
725
|
drop_input_cols=self._drop_input_cols,
|
687
726
|
expected_output_cols_type="float",
|
688
727
|
)
|
689
|
-
expected_output_cols = self.
|
728
|
+
expected_output_cols, _ = self._align_expected_output(
|
690
729
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
691
730
|
)
|
692
731
|
|
@@ -752,7 +791,7 @@ class SpectralBiclustering(BaseTransformer):
|
|
752
791
|
drop_input_cols=self._drop_input_cols,
|
753
792
|
expected_output_cols_type="float",
|
754
793
|
)
|
755
|
-
expected_output_cols = self.
|
794
|
+
expected_output_cols, _ = self._align_expected_output(
|
756
795
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
757
796
|
)
|
758
797
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -815,7 +854,7 @@ class SpectralBiclustering(BaseTransformer):
|
|
815
854
|
drop_input_cols=self._drop_input_cols,
|
816
855
|
expected_output_cols_type="float",
|
817
856
|
)
|
818
|
-
expected_output_cols = self.
|
857
|
+
expected_output_cols, _ = self._align_expected_output(
|
819
858
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
820
859
|
)
|
821
860
|
|
@@ -880,7 +919,7 @@ class SpectralBiclustering(BaseTransformer):
|
|
880
919
|
drop_input_cols = self._drop_input_cols,
|
881
920
|
expected_output_cols_type="float",
|
882
921
|
)
|
883
|
-
expected_output_cols = self.
|
922
|
+
expected_output_cols, _ = self._align_expected_output(
|
884
923
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
885
924
|
)
|
886
925
|
|