snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +7 -33
- snowflake/ml/_internal/env_utils.py +11 -5
- snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
- snowflake/ml/_internal/telemetry.py +156 -20
- snowflake/ml/_internal/utils/identifier.py +48 -11
- snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
- snowflake/ml/_internal/utils/snowflake_env.py +23 -13
- snowflake/ml/_internal/utils/sql_identifier.py +1 -1
- snowflake/ml/_internal/utils/table_manager.py +19 -1
- snowflake/ml/_internal/utils/uri.py +2 -2
- snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
- snowflake/ml/data/data_connector.py +88 -9
- snowflake/ml/data/data_ingestor.py +18 -1
- snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
- snowflake/ml/data/torch_utils.py +68 -0
- snowflake/ml/dataset/dataset.py +1 -3
- snowflake/ml/dataset/dataset_metadata.py +3 -1
- snowflake/ml/dataset/dataset_reader.py +9 -3
- snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
- snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
- snowflake/ml/feature_store/examples/example_helper.py +69 -31
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
- snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
- snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
- snowflake/ml/feature_store/feature_store.py +100 -41
- snowflake/ml/feature_store/feature_view.py +149 -5
- snowflake/ml/fileset/embedded_stage_fs.py +1 -1
- snowflake/ml/fileset/fileset.py +1 -1
- snowflake/ml/fileset/sfcfs.py +9 -3
- snowflake/ml/model/_client/model/model_impl.py +11 -2
- snowflake/ml/model/_client/model/model_version_impl.py +186 -20
- snowflake/ml/model/_client/ops/model_ops.py +144 -30
- snowflake/ml/model/_client/ops/service_ops.py +312 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
- snowflake/ml/model/_client/sql/model_version.py +13 -4
- snowflake/ml/model/_client/sql/service.py +196 -0
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
- snowflake/ml/model/_model_composer/model_composer.py +5 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
- snowflake/ml/model/_packager/model_env/model_env.py +7 -2
- snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
- snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
- snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
- snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
- snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
- snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
- snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
- snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
- snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
- snowflake/ml/model/_packager/model_packager.py +4 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/_signatures/utils.py +9 -0
- snowflake/ml/model/models/llm.py +3 -1
- snowflake/ml/model/type_hints.py +10 -4
- snowflake/ml/modeling/_internal/constants.py +1 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
- snowflake/ml/modeling/_internal/model_specifications.py +2 -0
- snowflake/ml/modeling/_internal/model_trainer.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
- snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
- snowflake/ml/modeling/cluster/birch.py +60 -21
- snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
- snowflake/ml/modeling/cluster/dbscan.py +60 -21
- snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
- snowflake/ml/modeling/cluster/k_means.py +60 -21
- snowflake/ml/modeling/cluster/mean_shift.py +60 -21
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
- snowflake/ml/modeling/cluster/optics.py +60 -21
- snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
- snowflake/ml/modeling/compose/column_transformer.py +60 -21
- snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
- snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
- snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
- snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
- snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
- snowflake/ml/modeling/covariance/oas.py +60 -21
- snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
- snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
- snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
- snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
- snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/pca.py +60 -21
- snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
- snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
- snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
- snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
- snowflake/ml/modeling/framework/base.py +28 -19
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
- snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
- snowflake/ml/modeling/impute/knn_imputer.py +60 -21
- snowflake/ml/modeling/impute/missing_indicator.py +60 -21
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/lars.py +60 -21
- snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
- snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/perceptron.py +60 -21
- snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ridge.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
- snowflake/ml/modeling/manifold/isomap.py +60 -21
- snowflake/ml/modeling/manifold/mds.py +60 -21
- snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
- snowflake/ml/modeling/manifold/tsne.py +60 -21
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
- snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
- snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +4 -12
- snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
- snowflake/ml/modeling/svm/linear_svc.py +60 -21
- snowflake/ml/modeling/svm/linear_svr.py +60 -21
- snowflake/ml/modeling/svm/nu_svc.py +60 -21
- snowflake/ml/modeling/svm/nu_svr.py +60 -21
- snowflake/ml/modeling/svm/svc.py +60 -21
- snowflake/ml/modeling/svm/svr.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
- snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
- snowflake/ml/registry/_manager/model_manager.py +20 -2
- snowflake/ml/registry/model_registry.py +1 -1
- snowflake/ml/registry/registry.py +1 -2
- snowflake/ml/utils/sql_client.py +22 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -517,12 +514,23 @@ class MultiTaskLasso(BaseTransformer):
|
|
517
514
|
autogenerated=self._autogenerated,
|
518
515
|
subproject=_SUBPROJECT,
|
519
516
|
)
|
520
|
-
|
521
|
-
|
522
|
-
expected_output_cols_list=(
|
523
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
524
|
-
),
|
517
|
+
expected_output_cols = (
|
518
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
525
519
|
)
|
520
|
+
if isinstance(dataset, DataFrame):
|
521
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
522
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
523
|
+
)
|
524
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
525
|
+
drop_input_cols=self._drop_input_cols,
|
526
|
+
expected_output_cols_list=expected_output_cols,
|
527
|
+
example_output_pd_df=example_output_pd_df,
|
528
|
+
)
|
529
|
+
else:
|
530
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
531
|
+
drop_input_cols=self._drop_input_cols,
|
532
|
+
expected_output_cols_list=expected_output_cols,
|
533
|
+
)
|
526
534
|
self._sklearn_object = fitted_estimator
|
527
535
|
self._is_fitted = True
|
528
536
|
return output_result
|
@@ -601,12 +609,41 @@ class MultiTaskLasso(BaseTransformer):
|
|
601
609
|
|
602
610
|
return rv
|
603
611
|
|
604
|
-
def
|
605
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
606
|
-
) -> List[str]:
|
612
|
+
def _align_expected_output(
|
613
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
614
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
615
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
616
|
+
and output dataframe with 1 line.
|
617
|
+
If the method is fit_predict, run 2 lines of data.
|
618
|
+
"""
|
607
619
|
# in case the inferred output column names dimension is different
|
608
620
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
609
|
-
|
621
|
+
|
622
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
623
|
+
# so change the minimum of number of rows to 2
|
624
|
+
num_examples = 2
|
625
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
626
|
+
project=_PROJECT,
|
627
|
+
subproject=_SUBPROJECT,
|
628
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
629
|
+
inspect.currentframe(), MultiTaskLasso.__class__.__name__
|
630
|
+
),
|
631
|
+
api_calls=[Session.call],
|
632
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
633
|
+
)
|
634
|
+
if output_cols_prefix == "fit_predict_":
|
635
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
636
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
637
|
+
num_examples = self._sklearn_object.n_clusters
|
638
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
639
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
640
|
+
num_examples = self._sklearn_object.min_samples
|
641
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
642
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
643
|
+
num_examples = self._sklearn_object.n_neighbors
|
644
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
645
|
+
else:
|
646
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
610
647
|
|
611
648
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
612
649
|
# seen during the fit.
|
@@ -618,12 +655,14 @@ class MultiTaskLasso(BaseTransformer):
|
|
618
655
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
619
656
|
if self.sample_weight_col:
|
620
657
|
output_df_columns_set -= set(self.sample_weight_col)
|
658
|
+
|
621
659
|
# if the dimension of inferred output column names is correct; use it
|
622
660
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
623
|
-
return expected_output_cols_list
|
661
|
+
return expected_output_cols_list, output_df_pd
|
624
662
|
# otherwise, use the sklearn estimator's output
|
625
663
|
else:
|
626
|
-
|
664
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
665
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
627
666
|
|
628
667
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
629
668
|
@telemetry.send_api_usage_telemetry(
|
@@ -669,7 +708,7 @@ class MultiTaskLasso(BaseTransformer):
|
|
669
708
|
drop_input_cols=self._drop_input_cols,
|
670
709
|
expected_output_cols_type="float",
|
671
710
|
)
|
672
|
-
expected_output_cols = self.
|
711
|
+
expected_output_cols, _ = self._align_expected_output(
|
673
712
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
674
713
|
)
|
675
714
|
|
@@ -735,7 +774,7 @@ class MultiTaskLasso(BaseTransformer):
|
|
735
774
|
drop_input_cols=self._drop_input_cols,
|
736
775
|
expected_output_cols_type="float",
|
737
776
|
)
|
738
|
-
expected_output_cols = self.
|
777
|
+
expected_output_cols, _ = self._align_expected_output(
|
739
778
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
740
779
|
)
|
741
780
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -798,7 +837,7 @@ class MultiTaskLasso(BaseTransformer):
|
|
798
837
|
drop_input_cols=self._drop_input_cols,
|
799
838
|
expected_output_cols_type="float",
|
800
839
|
)
|
801
|
-
expected_output_cols = self.
|
840
|
+
expected_output_cols, _ = self._align_expected_output(
|
802
841
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
803
842
|
)
|
804
843
|
|
@@ -863,7 +902,7 @@ class MultiTaskLasso(BaseTransformer):
|
|
863
902
|
drop_input_cols = self._drop_input_cols,
|
864
903
|
expected_output_cols_type="float",
|
865
904
|
)
|
866
|
-
expected_output_cols = self.
|
905
|
+
expected_output_cols, _ = self._align_expected_output(
|
867
906
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
868
907
|
)
|
869
908
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -552,12 +549,23 @@ class MultiTaskLassoCV(BaseTransformer):
|
|
552
549
|
autogenerated=self._autogenerated,
|
553
550
|
subproject=_SUBPROJECT,
|
554
551
|
)
|
555
|
-
|
556
|
-
|
557
|
-
expected_output_cols_list=(
|
558
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
559
|
-
),
|
552
|
+
expected_output_cols = (
|
553
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
560
554
|
)
|
555
|
+
if isinstance(dataset, DataFrame):
|
556
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
557
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
558
|
+
)
|
559
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
560
|
+
drop_input_cols=self._drop_input_cols,
|
561
|
+
expected_output_cols_list=expected_output_cols,
|
562
|
+
example_output_pd_df=example_output_pd_df,
|
563
|
+
)
|
564
|
+
else:
|
565
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
566
|
+
drop_input_cols=self._drop_input_cols,
|
567
|
+
expected_output_cols_list=expected_output_cols,
|
568
|
+
)
|
561
569
|
self._sklearn_object = fitted_estimator
|
562
570
|
self._is_fitted = True
|
563
571
|
return output_result
|
@@ -636,12 +644,41 @@ class MultiTaskLassoCV(BaseTransformer):
|
|
636
644
|
|
637
645
|
return rv
|
638
646
|
|
639
|
-
def
|
640
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
641
|
-
) -> List[str]:
|
647
|
+
def _align_expected_output(
|
648
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
649
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
650
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
651
|
+
and output dataframe with 1 line.
|
652
|
+
If the method is fit_predict, run 2 lines of data.
|
653
|
+
"""
|
642
654
|
# in case the inferred output column names dimension is different
|
643
655
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
644
|
-
|
656
|
+
|
657
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
658
|
+
# so change the minimum of number of rows to 2
|
659
|
+
num_examples = 2
|
660
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
661
|
+
project=_PROJECT,
|
662
|
+
subproject=_SUBPROJECT,
|
663
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
664
|
+
inspect.currentframe(), MultiTaskLassoCV.__class__.__name__
|
665
|
+
),
|
666
|
+
api_calls=[Session.call],
|
667
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
668
|
+
)
|
669
|
+
if output_cols_prefix == "fit_predict_":
|
670
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
671
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
672
|
+
num_examples = self._sklearn_object.n_clusters
|
673
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
674
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
675
|
+
num_examples = self._sklearn_object.min_samples
|
676
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
677
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
678
|
+
num_examples = self._sklearn_object.n_neighbors
|
679
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
680
|
+
else:
|
681
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
645
682
|
|
646
683
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
647
684
|
# seen during the fit.
|
@@ -653,12 +690,14 @@ class MultiTaskLassoCV(BaseTransformer):
|
|
653
690
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
654
691
|
if self.sample_weight_col:
|
655
692
|
output_df_columns_set -= set(self.sample_weight_col)
|
693
|
+
|
656
694
|
# if the dimension of inferred output column names is correct; use it
|
657
695
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
658
|
-
return expected_output_cols_list
|
696
|
+
return expected_output_cols_list, output_df_pd
|
659
697
|
# otherwise, use the sklearn estimator's output
|
660
698
|
else:
|
661
|
-
|
699
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
700
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
662
701
|
|
663
702
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
664
703
|
@telemetry.send_api_usage_telemetry(
|
@@ -704,7 +743,7 @@ class MultiTaskLassoCV(BaseTransformer):
|
|
704
743
|
drop_input_cols=self._drop_input_cols,
|
705
744
|
expected_output_cols_type="float",
|
706
745
|
)
|
707
|
-
expected_output_cols = self.
|
746
|
+
expected_output_cols, _ = self._align_expected_output(
|
708
747
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
709
748
|
)
|
710
749
|
|
@@ -770,7 +809,7 @@ class MultiTaskLassoCV(BaseTransformer):
|
|
770
809
|
drop_input_cols=self._drop_input_cols,
|
771
810
|
expected_output_cols_type="float",
|
772
811
|
)
|
773
|
-
expected_output_cols = self.
|
812
|
+
expected_output_cols, _ = self._align_expected_output(
|
774
813
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
775
814
|
)
|
776
815
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -833,7 +872,7 @@ class MultiTaskLassoCV(BaseTransformer):
|
|
833
872
|
drop_input_cols=self._drop_input_cols,
|
834
873
|
expected_output_cols_type="float",
|
835
874
|
)
|
836
|
-
expected_output_cols = self.
|
875
|
+
expected_output_cols, _ = self._align_expected_output(
|
837
876
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
838
877
|
)
|
839
878
|
|
@@ -898,7 +937,7 @@ class MultiTaskLassoCV(BaseTransformer):
|
|
898
937
|
drop_input_cols = self._drop_input_cols,
|
899
938
|
expected_output_cols_type="float",
|
900
939
|
)
|
901
|
-
expected_output_cols = self.
|
940
|
+
expected_output_cols, _ = self._align_expected_output(
|
902
941
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
903
942
|
)
|
904
943
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -500,12 +497,23 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
500
497
|
autogenerated=self._autogenerated,
|
501
498
|
subproject=_SUBPROJECT,
|
502
499
|
)
|
503
|
-
|
504
|
-
|
505
|
-
expected_output_cols_list=(
|
506
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
507
|
-
),
|
500
|
+
expected_output_cols = (
|
501
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
508
502
|
)
|
503
|
+
if isinstance(dataset, DataFrame):
|
504
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
505
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
506
|
+
)
|
507
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
508
|
+
drop_input_cols=self._drop_input_cols,
|
509
|
+
expected_output_cols_list=expected_output_cols,
|
510
|
+
example_output_pd_df=example_output_pd_df,
|
511
|
+
)
|
512
|
+
else:
|
513
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
514
|
+
drop_input_cols=self._drop_input_cols,
|
515
|
+
expected_output_cols_list=expected_output_cols,
|
516
|
+
)
|
509
517
|
self._sklearn_object = fitted_estimator
|
510
518
|
self._is_fitted = True
|
511
519
|
return output_result
|
@@ -584,12 +592,41 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
584
592
|
|
585
593
|
return rv
|
586
594
|
|
587
|
-
def
|
588
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
589
|
-
) -> List[str]:
|
595
|
+
def _align_expected_output(
|
596
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
597
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
598
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
599
|
+
and output dataframe with 1 line.
|
600
|
+
If the method is fit_predict, run 2 lines of data.
|
601
|
+
"""
|
590
602
|
# in case the inferred output column names dimension is different
|
591
603
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
592
|
-
|
604
|
+
|
605
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
606
|
+
# so change the minimum of number of rows to 2
|
607
|
+
num_examples = 2
|
608
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
609
|
+
project=_PROJECT,
|
610
|
+
subproject=_SUBPROJECT,
|
611
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
612
|
+
inspect.currentframe(), OrthogonalMatchingPursuit.__class__.__name__
|
613
|
+
),
|
614
|
+
api_calls=[Session.call],
|
615
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
616
|
+
)
|
617
|
+
if output_cols_prefix == "fit_predict_":
|
618
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
619
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
620
|
+
num_examples = self._sklearn_object.n_clusters
|
621
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
622
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
623
|
+
num_examples = self._sklearn_object.min_samples
|
624
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
625
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
626
|
+
num_examples = self._sklearn_object.n_neighbors
|
627
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
628
|
+
else:
|
629
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
593
630
|
|
594
631
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
595
632
|
# seen during the fit.
|
@@ -601,12 +638,14 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
601
638
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
602
639
|
if self.sample_weight_col:
|
603
640
|
output_df_columns_set -= set(self.sample_weight_col)
|
641
|
+
|
604
642
|
# if the dimension of inferred output column names is correct; use it
|
605
643
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
606
|
-
return expected_output_cols_list
|
644
|
+
return expected_output_cols_list, output_df_pd
|
607
645
|
# otherwise, use the sklearn estimator's output
|
608
646
|
else:
|
609
|
-
|
647
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
648
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
610
649
|
|
611
650
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
612
651
|
@telemetry.send_api_usage_telemetry(
|
@@ -652,7 +691,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
652
691
|
drop_input_cols=self._drop_input_cols,
|
653
692
|
expected_output_cols_type="float",
|
654
693
|
)
|
655
|
-
expected_output_cols = self.
|
694
|
+
expected_output_cols, _ = self._align_expected_output(
|
656
695
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
657
696
|
)
|
658
697
|
|
@@ -718,7 +757,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
718
757
|
drop_input_cols=self._drop_input_cols,
|
719
758
|
expected_output_cols_type="float",
|
720
759
|
)
|
721
|
-
expected_output_cols = self.
|
760
|
+
expected_output_cols, _ = self._align_expected_output(
|
722
761
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
723
762
|
)
|
724
763
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -781,7 +820,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
781
820
|
drop_input_cols=self._drop_input_cols,
|
782
821
|
expected_output_cols_type="float",
|
783
822
|
)
|
784
|
-
expected_output_cols = self.
|
823
|
+
expected_output_cols, _ = self._align_expected_output(
|
785
824
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
786
825
|
)
|
787
826
|
|
@@ -846,7 +885,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
846
885
|
drop_input_cols = self._drop_input_cols,
|
847
886
|
expected_output_cols_type="float",
|
848
887
|
)
|
849
|
-
expected_output_cols = self.
|
888
|
+
expected_output_cols, _ = self._align_expected_output(
|
850
889
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
851
890
|
)
|
852
891
|
|
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -574,12 +571,23 @@ class PassiveAggressiveClassifier(BaseTransformer):
|
|
574
571
|
autogenerated=self._autogenerated,
|
575
572
|
subproject=_SUBPROJECT,
|
576
573
|
)
|
577
|
-
|
578
|
-
|
579
|
-
expected_output_cols_list=(
|
580
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
581
|
-
),
|
574
|
+
expected_output_cols = (
|
575
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
582
576
|
)
|
577
|
+
if isinstance(dataset, DataFrame):
|
578
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
579
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
580
|
+
)
|
581
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
582
|
+
drop_input_cols=self._drop_input_cols,
|
583
|
+
expected_output_cols_list=expected_output_cols,
|
584
|
+
example_output_pd_df=example_output_pd_df,
|
585
|
+
)
|
586
|
+
else:
|
587
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
588
|
+
drop_input_cols=self._drop_input_cols,
|
589
|
+
expected_output_cols_list=expected_output_cols,
|
590
|
+
)
|
583
591
|
self._sklearn_object = fitted_estimator
|
584
592
|
self._is_fitted = True
|
585
593
|
return output_result
|
@@ -658,12 +666,41 @@ class PassiveAggressiveClassifier(BaseTransformer):
|
|
658
666
|
|
659
667
|
return rv
|
660
668
|
|
661
|
-
def
|
662
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
663
|
-
) -> List[str]:
|
669
|
+
def _align_expected_output(
|
670
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
671
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
672
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
673
|
+
and output dataframe with 1 line.
|
674
|
+
If the method is fit_predict, run 2 lines of data.
|
675
|
+
"""
|
664
676
|
# in case the inferred output column names dimension is different
|
665
677
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
666
|
-
|
678
|
+
|
679
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
680
|
+
# so change the minimum of number of rows to 2
|
681
|
+
num_examples = 2
|
682
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
683
|
+
project=_PROJECT,
|
684
|
+
subproject=_SUBPROJECT,
|
685
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
686
|
+
inspect.currentframe(), PassiveAggressiveClassifier.__class__.__name__
|
687
|
+
),
|
688
|
+
api_calls=[Session.call],
|
689
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
690
|
+
)
|
691
|
+
if output_cols_prefix == "fit_predict_":
|
692
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
693
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
694
|
+
num_examples = self._sklearn_object.n_clusters
|
695
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
696
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
697
|
+
num_examples = self._sklearn_object.min_samples
|
698
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
699
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
700
|
+
num_examples = self._sklearn_object.n_neighbors
|
701
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
702
|
+
else:
|
703
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
667
704
|
|
668
705
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
669
706
|
# seen during the fit.
|
@@ -675,12 +712,14 @@ class PassiveAggressiveClassifier(BaseTransformer):
|
|
675
712
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
676
713
|
if self.sample_weight_col:
|
677
714
|
output_df_columns_set -= set(self.sample_weight_col)
|
715
|
+
|
678
716
|
# if the dimension of inferred output column names is correct; use it
|
679
717
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
680
|
-
return expected_output_cols_list
|
718
|
+
return expected_output_cols_list, output_df_pd
|
681
719
|
# otherwise, use the sklearn estimator's output
|
682
720
|
else:
|
683
|
-
|
721
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
722
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
684
723
|
|
685
724
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
686
725
|
@telemetry.send_api_usage_telemetry(
|
@@ -726,7 +765,7 @@ class PassiveAggressiveClassifier(BaseTransformer):
|
|
726
765
|
drop_input_cols=self._drop_input_cols,
|
727
766
|
expected_output_cols_type="float",
|
728
767
|
)
|
729
|
-
expected_output_cols = self.
|
768
|
+
expected_output_cols, _ = self._align_expected_output(
|
730
769
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
731
770
|
)
|
732
771
|
|
@@ -792,7 +831,7 @@ class PassiveAggressiveClassifier(BaseTransformer):
|
|
792
831
|
drop_input_cols=self._drop_input_cols,
|
793
832
|
expected_output_cols_type="float",
|
794
833
|
)
|
795
|
-
expected_output_cols = self.
|
834
|
+
expected_output_cols, _ = self._align_expected_output(
|
796
835
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
797
836
|
)
|
798
837
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -857,7 +896,7 @@ class PassiveAggressiveClassifier(BaseTransformer):
|
|
857
896
|
drop_input_cols=self._drop_input_cols,
|
858
897
|
expected_output_cols_type="float",
|
859
898
|
)
|
860
|
-
expected_output_cols = self.
|
899
|
+
expected_output_cols, _ = self._align_expected_output(
|
861
900
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
862
901
|
)
|
863
902
|
|
@@ -922,7 +961,7 @@ class PassiveAggressiveClassifier(BaseTransformer):
|
|
922
961
|
drop_input_cols = self._drop_input_cols,
|
923
962
|
expected_output_cols_type="float",
|
924
963
|
)
|
925
|
-
expected_output_cols = self.
|
964
|
+
expected_output_cols, _ = self._align_expected_output(
|
926
965
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
927
966
|
)
|
928
967
|
|