snowflake-ml-python 1.1.1__py3-none-any.whl → 1.1.2__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
- snowflake/cortex/_complete.py +1 -1
- snowflake/cortex/_extract_answer.py +1 -1
- snowflake/cortex/_sentiment.py +1 -1
- snowflake/cortex/_summarize.py +1 -1
- snowflake/cortex/_translate.py +1 -1
- snowflake/ml/_internal/env_utils.py +68 -6
- snowflake/ml/_internal/file_utils.py +34 -4
- snowflake/ml/_internal/telemetry.py +79 -91
- snowflake/ml/_internal/utils/retryable_http.py +16 -4
- snowflake/ml/_internal/utils/spcs_attribution_utils.py +122 -0
- snowflake/ml/dataset/dataset.py +1 -1
- snowflake/ml/model/_api.py +21 -14
- snowflake/ml/model/_client/model/model_impl.py +176 -0
- snowflake/ml/model/_client/model/model_method_info.py +19 -0
- snowflake/ml/model/_client/model/model_version_impl.py +291 -0
- snowflake/ml/model/_client/ops/metadata_ops.py +107 -0
- snowflake/ml/model/_client/ops/model_ops.py +308 -0
- snowflake/ml/model/_client/sql/model.py +75 -0
- snowflake/ml/model/_client/sql/model_version.py +213 -0
- snowflake/ml/model/_client/sql/stage.py +40 -0
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -4
- snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +24 -8
- snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +23 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +14 -2
- snowflake/ml/model/_deploy_client/utils/constants.py +1 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +2 -2
- snowflake/ml/model/_model_composer/model_composer.py +31 -9
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +25 -10
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -2
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +34 -3
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +1 -1
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +3 -1
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +10 -28
- snowflake/ml/model/_packager/model_meta/model_meta.py +18 -16
- snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
- snowflake/ml/model/model_signature.py +108 -53
- snowflake/ml/model/type_hints.py +1 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +554 -0
- snowflake/ml/modeling/_internal/estimator_protocols.py +1 -60
- snowflake/ml/modeling/_internal/model_specifications.py +146 -0
- snowflake/ml/modeling/_internal/model_trainer.py +13 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +78 -0
- snowflake/ml/modeling/_internal/pandas_trainer.py +54 -0
- snowflake/ml/modeling/_internal/snowpark_handlers.py +6 -760
- snowflake/ml/modeling/_internal/snowpark_trainer.py +331 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +96 -124
- snowflake/ml/modeling/cluster/affinity_propagation.py +94 -124
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +94 -124
- snowflake/ml/modeling/cluster/birch.py +94 -124
- snowflake/ml/modeling/cluster/bisecting_k_means.py +94 -124
- snowflake/ml/modeling/cluster/dbscan.py +94 -124
- snowflake/ml/modeling/cluster/feature_agglomeration.py +94 -124
- snowflake/ml/modeling/cluster/k_means.py +93 -124
- snowflake/ml/modeling/cluster/mean_shift.py +94 -124
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +93 -124
- snowflake/ml/modeling/cluster/optics.py +94 -124
- snowflake/ml/modeling/cluster/spectral_biclustering.py +94 -124
- snowflake/ml/modeling/cluster/spectral_clustering.py +94 -124
- snowflake/ml/modeling/cluster/spectral_coclustering.py +94 -124
- snowflake/ml/modeling/compose/column_transformer.py +94 -124
- snowflake/ml/modeling/compose/transformed_target_regressor.py +96 -124
- snowflake/ml/modeling/covariance/elliptic_envelope.py +94 -124
- snowflake/ml/modeling/covariance/empirical_covariance.py +80 -110
- snowflake/ml/modeling/covariance/graphical_lasso.py +94 -124
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +94 -124
- snowflake/ml/modeling/covariance/ledoit_wolf.py +85 -115
- snowflake/ml/modeling/covariance/min_cov_det.py +94 -124
- snowflake/ml/modeling/covariance/oas.py +80 -110
- snowflake/ml/modeling/covariance/shrunk_covariance.py +84 -114
- snowflake/ml/modeling/decomposition/dictionary_learning.py +94 -124
- snowflake/ml/modeling/decomposition/factor_analysis.py +94 -124
- snowflake/ml/modeling/decomposition/fast_ica.py +94 -124
- snowflake/ml/modeling/decomposition/incremental_pca.py +94 -124
- snowflake/ml/modeling/decomposition/kernel_pca.py +94 -124
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +94 -124
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +94 -124
- snowflake/ml/modeling/decomposition/pca.py +94 -124
- snowflake/ml/modeling/decomposition/sparse_pca.py +94 -124
- snowflake/ml/modeling/decomposition/truncated_svd.py +94 -124
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +96 -124
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +91 -119
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/bagging_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/bagging_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/isolation_forest.py +94 -124
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/stacking_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/voting_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/voting_regressor.py +91 -119
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +82 -110
- snowflake/ml/modeling/feature_selection/select_fdr.py +80 -108
- snowflake/ml/modeling/feature_selection/select_fpr.py +80 -108
- snowflake/ml/modeling/feature_selection/select_fwe.py +80 -108
- snowflake/ml/modeling/feature_selection/select_k_best.py +81 -109
- snowflake/ml/modeling/feature_selection/select_percentile.py +80 -108
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +94 -124
- snowflake/ml/modeling/feature_selection/variance_threshold.py +76 -106
- snowflake/ml/modeling/framework/base.py +2 -2
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +96 -124
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +96 -124
- snowflake/ml/modeling/impute/iterative_imputer.py +94 -124
- snowflake/ml/modeling/impute/knn_imputer.py +94 -124
- snowflake/ml/modeling/impute/missing_indicator.py +94 -124
- snowflake/ml/modeling/impute/simple_imputer.py +1 -1
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +77 -107
- snowflake/ml/modeling/kernel_approximation/nystroem.py +94 -124
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +94 -124
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +86 -116
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +84 -114
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +96 -124
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +71 -100
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +71 -100
- snowflake/ml/modeling/linear_model/ard_regression.py +96 -124
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +96 -124
- snowflake/ml/modeling/linear_model/elastic_net.py +96 -124
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +96 -124
- snowflake/ml/modeling/linear_model/gamma_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/huber_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/lars.py +96 -124
- snowflake/ml/modeling/linear_model/lars_cv.py +96 -124
- snowflake/ml/modeling/linear_model/lasso.py +96 -124
- snowflake/ml/modeling/linear_model/lasso_cv.py +96 -124
- snowflake/ml/modeling/linear_model/lasso_lars.py +96 -124
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +96 -124
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +96 -124
- snowflake/ml/modeling/linear_model/linear_regression.py +91 -119
- snowflake/ml/modeling/linear_model/logistic_regression.py +96 -124
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +96 -124
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +96 -124
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +96 -124
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +96 -124
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +96 -124
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +96 -124
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +96 -124
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +95 -124
- snowflake/ml/modeling/linear_model/perceptron.py +95 -124
- snowflake/ml/modeling/linear_model/poisson_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/ransac_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/ridge.py +96 -124
- snowflake/ml/modeling/linear_model/ridge_classifier.py +96 -124
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +96 -124
- snowflake/ml/modeling/linear_model/ridge_cv.py +96 -124
- snowflake/ml/modeling/linear_model/sgd_classifier.py +96 -124
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +94 -124
- snowflake/ml/modeling/linear_model/sgd_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +96 -124
- snowflake/ml/modeling/manifold/isomap.py +94 -124
- snowflake/ml/modeling/manifold/mds.py +94 -124
- snowflake/ml/modeling/manifold/spectral_embedding.py +94 -124
- snowflake/ml/modeling/manifold/tsne.py +94 -124
- snowflake/ml/modeling/metrics/classification.py +187 -52
- snowflake/ml/modeling/metrics/correlation.py +4 -2
- snowflake/ml/modeling/metrics/covariance.py +7 -4
- snowflake/ml/modeling/metrics/ranking.py +32 -16
- snowflake/ml/modeling/metrics/regression.py +60 -32
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +94 -124
- snowflake/ml/modeling/mixture/gaussian_mixture.py +94 -124
- snowflake/ml/modeling/model_selection/grid_search_cv.py +88 -138
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +90 -144
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +86 -114
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +93 -121
- snowflake/ml/modeling/multiclass/output_code_classifier.py +94 -122
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +92 -120
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +96 -124
- snowflake/ml/modeling/naive_bayes/complement_nb.py +92 -120
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -107
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +88 -116
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +96 -124
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +96 -124
- snowflake/ml/modeling/neighbors/kernel_density.py +94 -124
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +94 -124
- snowflake/ml/modeling/neighbors/nearest_centroid.py +89 -117
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +94 -124
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +96 -124
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +96 -124
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +96 -124
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +94 -124
- snowflake/ml/modeling/neural_network/mlp_classifier.py +96 -124
- snowflake/ml/modeling/neural_network/mlp_regressor.py +96 -124
- snowflake/ml/modeling/parameters/disable_distributed_hpo.py +2 -6
- snowflake/ml/modeling/preprocessing/binarizer.py +14 -9
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +0 -4
- snowflake/ml/modeling/preprocessing/label_encoder.py +21 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +20 -14
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +35 -19
- snowflake/ml/modeling/preprocessing/normalizer.py +6 -9
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +20 -13
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +25 -13
- snowflake/ml/modeling/preprocessing/polynomial_features.py +94 -124
- snowflake/ml/modeling/preprocessing/robust_scaler.py +28 -14
- snowflake/ml/modeling/preprocessing/standard_scaler.py +25 -13
- snowflake/ml/modeling/semi_supervised/label_propagation.py +96 -124
- snowflake/ml/modeling/semi_supervised/label_spreading.py +96 -124
- snowflake/ml/modeling/svm/linear_svc.py +96 -124
- snowflake/ml/modeling/svm/linear_svr.py +96 -124
- snowflake/ml/modeling/svm/nu_svc.py +96 -124
- snowflake/ml/modeling/svm/nu_svr.py +96 -124
- snowflake/ml/modeling/svm/svc.py +96 -124
- snowflake/ml/modeling/svm/svr.py +96 -124
- snowflake/ml/modeling/tree/decision_tree_classifier.py +96 -124
- snowflake/ml/modeling/tree/decision_tree_regressor.py +96 -124
- snowflake/ml/modeling/tree/extra_tree_classifier.py +96 -124
- snowflake/ml/modeling/tree/extra_tree_regressor.py +96 -124
- snowflake/ml/modeling/xgboost/xgb_classifier.py +96 -125
- snowflake/ml/modeling/xgboost/xgb_regressor.py +96 -125
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +96 -125
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +96 -125
- snowflake/ml/registry/model_registry.py +2 -0
- snowflake/ml/registry/registry.py +215 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.1.1.dist-info → snowflake_ml_python-1.1.2.dist-info}/METADATA +21 -3
- snowflake_ml_python-1.1.2.dist-info/RECORD +347 -0
- snowflake_ml_python-1.1.1.dist-info/RECORD +0 -331
- {snowflake_ml_python-1.1.1.dist-info → snowflake_ml_python-1.1.2.dist-info}/WHEEL +0 -0
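The bulk of this release is a refactor of the modeling internals: `snowpark_handlers.py` shrinks by 760 lines, and fitting logic moves into the new `model_trainer.py`, `model_trainer_builder.py`, `pandas_trainer.py`, `snowpark_trainer.py`, and `distributed_hpo_trainer.py` modules, alongside a new model registry client under `snowflake/ml/model/_client` and `snowflake/ml/registry/registry.py`. As a rough illustration of the builder pattern those filenames suggest — the class and function names below are hypothetical stand-ins, not the library's actual API — a trainer builder dispatches on where the dataset lives:

```python
from typing import Protocol, Union

import pandas as pd


class Trainer(Protocol):
    def train(self) -> object:
        ...


class LocalTrainer:
    """Hypothetical stand-in: fits an estimator on an in-memory pandas DataFrame."""

    def __init__(self, estimator: object, dataset: pd.DataFrame) -> None:
        self.estimator = estimator
        self.dataset = dataset

    def train(self) -> object:
        # Placeholder: a real implementation would call estimator.fit(...) locally.
        return self.estimator


class WarehouseTrainer:
    """Hypothetical stand-in: fits an estimator inside the warehouse."""

    def __init__(self, estimator: object, dataset: object) -> None:
        self.estimator = estimator
        self.dataset = dataset

    def train(self) -> object:
        # Placeholder: a real implementation would serialize the estimator,
        # stage the data, and run the fit in a stored procedure.
        return self.estimator


def build_trainer(estimator: object, dataset: Union[pd.DataFrame, object]) -> Trainer:
    """Pick a trainer implementation based on the dataset type."""
    if isinstance(dataset, pd.DataFrame):
        return LocalTrainer(estimator, dataset)
    return WarehouseTrainer(estimator, dataset)
```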
snowflake/ml/modeling/_internal/distributed_hpo_trainer.py (new file)
@@ -0,0 +1,554 @@
```python
import importlib
import inspect
import io
import os
import posixpath
import sys
from typing import Any, Dict, List, Optional, Tuple, Union

import cloudpickle as cp
import numpy as np
from scipy.stats import rankdata
from sklearn import model_selection

from snowflake.ml._internal import telemetry
from snowflake.ml._internal.utils import identifier, snowpark_dataframe_utils
from snowflake.ml._internal.utils.temp_file_utils import (
    cleanup_temp_files,
    get_temp_file_path,
)
from snowflake.ml.modeling._internal.model_specifications import (
    ModelSpecificationsBuilder,
)
from snowflake.ml.modeling._internal.snowpark_trainer import SnowparkModelTrainer
from snowflake.snowpark import DataFrame, Session, functions as F
from snowflake.snowpark._internal.utils import (
    TempObjectType,
    random_name_for_temp_object,
)
from snowflake.snowpark.functions import col, sproc, udtf
from snowflake.snowpark.types import IntegerType, StringType, StructField, StructType

cp.register_pickle_by_value(inspect.getmodule(get_temp_file_path))
cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))

_PROJECT = "ModelDevelopment"
DEFAULT_UDTF_NJOBS = 3


class DistributedHPOTrainer(SnowparkModelTrainer):
    """
    A class for performing distributed hyperparameter optimization (HPO) using Snowpark.

    This class inherits from SnowparkModelTrainer and extends its functionality
    to support distributed HPO for machine learning models. It enables optimization
    of hyperparameters by distributing the tasks across the warehouse using Snowpark.
    """

    def __init__(
        self,
        estimator: object,
        dataset: DataFrame,
        session: Session,
        input_cols: List[str],
        label_cols: Optional[List[str]],
        sample_weight_col: Optional[str],
        autogenerated: bool = False,
        subproject: str = "",
    ) -> None:
        """
        Initializes the DistributedHPOTrainer with a model, a Snowpark DataFrame, feature, and label column names, etc.

        Args:
            estimator: SKLearn compatible estimator or transformer object.
            dataset: The dataset used for training the model.
            session: Snowflake session object to be used for training.
            input_cols: The name(s) of one or more columns in a DataFrame containing a feature to be used for training.
            label_cols: The name(s) of one or more columns in a DataFrame representing the target variable(s) to learn.
            sample_weight_col: The column name representing the weight of training examples.
            autogenerated: A boolean denoting if the trainer is being used by autogenerated code or not.
            subproject: subproject name to be used in telemetry.
        """
        super().__init__(
            estimator=estimator,
            dataset=dataset,
            session=session,
            input_cols=input_cols,
            label_cols=label_cols,
            sample_weight_col=sample_weight_col,
            autogenerated=autogenerated,
            subproject=subproject,
        )

    # TODO(snandamuri): Copied this code as it is from the snowpark_handler.
    #   Update it to improve the readability.
    def fit_search_snowpark(
        self,
        param_grid: Union[model_selection.ParameterGrid, model_selection.ParameterSampler],
        dataset: DataFrame,
        session: Session,
        estimator: Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV],
        dependencies: List[str],
        udf_imports: List[str],
        input_cols: List[str],
        label_cols: Optional[List[str]],
        sample_weight_col: Optional[str],
    ) -> Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV]:
        from itertools import product

        import cachetools
        from sklearn.base import clone, is_classifier
        from sklearn.calibration import check_cv

        # Create one stage for data and for estimators.
        temp_stage_name = random_name_for_temp_object(TempObjectType.STAGE)
        temp_stage_creation_query = f"CREATE OR REPLACE TEMP STAGE {temp_stage_name};"
        session.sql(temp_stage_creation_query).collect()

        # Stage data.
        dataset = snowpark_dataframe_utils.cast_snowpark_dataframe(dataset)
        remote_file_path = f"{temp_stage_name}/{temp_stage_name}.parquet"
        dataset.write.copy_into_location(  # type: ignore[call-overload]
            remote_file_path, file_format_type="parquet", header=True, overwrite=True
        )
        imports = [f"@{row.name}" for row in session.sql(f"LIST @{temp_stage_name}").collect()]

        # Store GridSearchCV's refit variable. If user set it as False, we don't need to refit it again
        original_refit = estimator.refit

        # Create a temp file and dump the estimator to that file.
        estimator_file_name = get_temp_file_path()
        params_to_evaluate = []
        for param_to_eval in list(param_grid):
            for k, v in param_to_eval.items():
                param_to_eval[k] = [v]
            params_to_evaluate.append([param_to_eval])

        with open(estimator_file_name, mode="w+b") as local_estimator_file_obj:
            # Set GridSearchCV refit as False and fit it again after retrieving the best param
            estimator.refit = False
            cp.dump(dict(estimator=estimator, param_grid=params_to_evaluate), local_estimator_file_obj)
        stage_estimator_file_name = posixpath.join(temp_stage_name, os.path.basename(estimator_file_name))
        sproc_statement_params = telemetry.get_function_usage_statement_params(
            project=_PROJECT,
            subproject=self._subproject,
            function_name=telemetry.get_statement_params_full_func_name(
                inspect.currentframe(), self.__class__.__name__
            ),
            api_calls=[sproc],
            custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
        )
        udtf_statement_params = telemetry.get_function_usage_statement_params(
            project=_PROJECT,
            subproject=self._subproject,
            function_name=telemetry.get_statement_params_full_func_name(
                inspect.currentframe(), self.__class__.__name__
            ),
            api_calls=[udtf],
            custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
        )

        # Put locally serialized estimator on stage.
        put_result = session.file.put(
            estimator_file_name,
            temp_stage_name,
            auto_compress=False,
            overwrite=True,
        )
        estimator_location = put_result[0].target
        imports.append(f"@{temp_stage_name}/{estimator_location}")

        search_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
        random_udtf_name = random_name_for_temp_object(TempObjectType.FUNCTION)

        required_deps = dependencies + [
            "snowflake-snowpark-python<2",
            "fastparquet<2023.11",
            "pyarrow<14",
            "cachetools<5",
        ]

        @sproc(  # type: ignore[misc]
            is_permanent=False,
            name=search_sproc_name,
            packages=required_deps,  # type: ignore[arg-type]
            replace=True,
            session=session,
            anonymous=True,
            imports=imports,  # type: ignore[arg-type]
            statement_params=sproc_statement_params,
        )
        def _distributed_search(
            session: Session,
            imports: List[str],
            stage_estimator_file_name: str,
            input_cols: List[str],
            label_cols: Optional[List[str]],
        ) -> str:
            import os
            import time
            from typing import Iterator

            import cloudpickle as cp
            import pandas as pd
            import pyarrow.parquet as pq
            from sklearn.metrics import check_scoring
            from sklearn.metrics._scorer import _check_multimetric_scoring

            for import_name in udf_imports:
                importlib.import_module(import_name)

            data_files = [
                filename
                for filename in os.listdir(sys._xoptions["snowflake_import_directory"])
                if filename.startswith(temp_stage_name)
            ]
            partial_df = [
                pq.read_table(os.path.join(sys._xoptions["snowflake_import_directory"], file_name)).to_pandas()
                for file_name in data_files
            ]
            df = pd.concat(partial_df, ignore_index=True)
            df.columns = [identifier.get_inferred_name(col) for col in df.columns]

            X = df[input_cols]
            y = df[label_cols].squeeze() if label_cols else None

            local_estimator_file_name = get_temp_file_path()
            session.file.get(stage_estimator_file_name, local_estimator_file_name)

            local_estimator_file_path = os.path.join(
                local_estimator_file_name, os.listdir(local_estimator_file_name)[0]
            )
            with open(local_estimator_file_path, mode="r+b") as local_estimator_file_obj:
                estimator = cp.load(local_estimator_file_obj)["estimator"]

            cv_orig = check_cv(estimator.cv, y, classifier=is_classifier(estimator.estimator))
            indices = [test for _, test in cv_orig.split(X, y)]
            local_indices_file_name = get_temp_file_path()
            with open(local_indices_file_name, mode="w+b") as local_indices_file_obj:
                cp.dump(indices, local_indices_file_obj)

            # Put locally serialized indices on stage.
            put_result = session.file.put(
                local_indices_file_name,
                temp_stage_name,
                auto_compress=False,
                overwrite=True,
            )
            indices_location = put_result[0].target
            imports.append(f"@{temp_stage_name}/{indices_location}")
            indices_len = len(indices)

            assert estimator is not None

            @cachetools.cached(cache={})
            def _load_data_into_udf() -> Tuple[
                Dict[str, pd.DataFrame],
                Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV],
                pd.DataFrame,
                int,
                List[Dict[str, Any]],
            ]:
                import pyarrow.parquet as pq

                data_files = [
                    filename
                    for filename in os.listdir(sys._xoptions["snowflake_import_directory"])
                    if filename.startswith(temp_stage_name)
                ]
                partial_df = [
                    pq.read_table(os.path.join(sys._xoptions["snowflake_import_directory"], file_name)).to_pandas()
                    for file_name in data_files
                ]
                df = pd.concat(partial_df, ignore_index=True)
                df.columns = [identifier.get_inferred_name(col) for col in df.columns]

                # load estimator
                local_estimator_file_path = os.path.join(
                    sys._xoptions["snowflake_import_directory"], f"{estimator_location}"
                )
                with open(local_estimator_file_path, mode="rb") as local_estimator_file_obj:
                    estimator_objects = cp.load(local_estimator_file_obj)
                    estimator = estimator_objects["estimator"]
                    params_to_evaluate = estimator_objects["param_grid"]

                # load indices
                local_indices_file_path = os.path.join(
                    sys._xoptions["snowflake_import_directory"], f"{indices_location}"
                )
                with open(local_indices_file_path, mode="rb") as local_indices_file_obj:
                    indices = cp.load(local_indices_file_obj)

                argspec = inspect.getfullargspec(estimator.fit)
                args = {"X": df[input_cols]}

                if label_cols:
                    label_arg_name = "Y" if "Y" in argspec.args else "y"
                    args[label_arg_name] = df[label_cols].squeeze()

                if sample_weight_col is not None and "sample_weight" in argspec.args:
                    args["sample_weight"] = df[sample_weight_col].squeeze()
                return args, estimator, indices, len(df), params_to_evaluate

            class SearchCV:
                def __init__(self) -> None:
                    args, estimator, indices, data_length, params_to_evaluate = _load_data_into_udf()
                    self.args = args
                    self.estimator = estimator
                    self.indices = indices
                    self.data_length = data_length
                    self.params_to_evaluate = params_to_evaluate

                def process(self, params_idx: int, idx: int) -> Iterator[Tuple[str]]:
                    if hasattr(estimator, "param_grid"):
                        self.estimator.param_grid = self.params_to_evaluate[params_idx]
                    else:
                        self.estimator.param_distributions = self.params_to_evaluate[params_idx]
                    full_indices = np.array([i for i in range(self.data_length)])
                    test_indice = self.indices[idx]
                    train_indice = np.setdiff1d(full_indices, test_indice)
                    self.estimator.cv = [(train_indice, test_indice)]
                    self.estimator.fit(**self.args)
                    binary_cv_results = None
                    with io.BytesIO() as f:
                        cp.dump(self.estimator.cv_results_, f)
                        f.seek(0)
                        binary_cv_results = f.getvalue().hex()
                    yield (binary_cv_results,)

                def end_partition(self) -> None:
                    ...

            session.udtf.register(
                SearchCV,
                output_schema=StructType([StructField("CV_RESULTS", StringType())]),
                input_types=[IntegerType(), IntegerType()],
                name=random_udtf_name,
                packages=required_deps,  # type: ignore[arg-type]
                replace=True,
                is_permanent=False,
                imports=imports,  # type: ignore[arg-type]
                statement_params=udtf_statement_params,
            )

            HP_TUNING = F.table_function(random_udtf_name)

            idx_length = int(indices_len)
            params_length = len(param_grid)
            idxs = [i for i in range(idx_length)]
            param_indices, training_indices = [], []
            for param_idx, cv_idx in product([param_index for param_index in range(params_length)], idxs):
                param_indices.append(param_idx)
                training_indices.append(cv_idx)

            pd_df = pd.DataFrame(
                {
                    "PARAMS": param_indices,
                    "TRAIN_IND": training_indices,
                    "PARAM_INDEX": [i for i in range(idx_length * params_length)],
                }
            )
            df = session.create_dataframe(pd_df)
            results = df.select(
                F.cast(df["PARAM_INDEX"], IntegerType()).as_("PARAM_INDEX"),
                (HP_TUNING(df["PARAMS"], df["TRAIN_IND"]).over(partition_by=df["PARAM_INDEX"])),
            )

            # cv_result maintains the original order
            multimetric = False
            cv_results_ = dict()
            scorers = set()
            for i, val in enumerate(results.select("CV_RESULTS").sort(col("PARAM_INDEX")).collect()):
                # retrieved string had one more double quote in the front and end of the string.
                # use [1:-1] to remove the extra double quotes
                hex_str = bytes.fromhex(val[0])
                with io.BytesIO(hex_str) as f_reload:
                    each_cv_result = cp.load(f_reload)
                    for k, v in each_cv_result.items():
                        cur_cv = i % idx_length
                        key = k
                        if "split0_test_" in k:
                            # For multi-metric evaluation, the scores for all the scorers are available in the
                            # cv_results_ dict at the keys ending with that scorer's name ('_<scorer_name>')
                            # instead of '_score'.
                            scorers.add(k[len("split0_test_") :])
                            key = k.replace("split0_test", f"split{cur_cv}_test")
                        elif k.startswith("param"):
                            if cur_cv != 0:
                                key = False
                        if key:
                            if key not in cv_results_:
                                cv_results_[key] = v
                            else:
                                cv_results_[key] = np.concatenate([cv_results_[key], v])

            multimetric = len(scorers) > 1
            # Use numpy to re-calculate all the information in cv_results_ again
            # Generally speaking, reshape all the results into the (scorers+2, idx_length, params_length) shape,
            # and average them by the idx_length;
            # idx_length is the number of cv folds; params_length is the number of parameter combinations
            scores = [
                np.reshape(
                    np.concatenate([cv_results_[f"split{cur_cv}_test_{score}"] for cur_cv in range(idx_length)]),
                    (idx_length, -1),
                )
                for score in scorers
            ]

            fit_score_test_matrix = np.stack(
                [
                    np.reshape(cv_results_["mean_fit_time"], (idx_length, -1)),
                    np.reshape(cv_results_["mean_score_time"], (idx_length, -1)),
                ]
                + scores
            )

            mean_fit_score_test_matrix = np.mean(fit_score_test_matrix, axis=1)
            std_fit_score_test_matrix = np.std(fit_score_test_matrix, axis=1)
            cv_results_["std_fit_time"] = std_fit_score_test_matrix[0]
            cv_results_["mean_fit_time"] = mean_fit_score_test_matrix[0]
            cv_results_["std_score_time"] = std_fit_score_test_matrix[1]
            cv_results_["mean_score_time"] = mean_fit_score_test_matrix[1]
            for idx, score in enumerate(scorers):
                cv_results_[f"std_test_{score}"] = std_fit_score_test_matrix[idx + 2]
                cv_results_[f"mean_test_{score}"] = mean_fit_score_test_matrix[idx + 2]
                # re-compute the ranking again with mean_test_<score>.
                cv_results_[f"rank_test_{score}"] = rankdata(-cv_results_[f"mean_test_{score}"], method="min")
                # The best param is the highest ranking (which is 1) and we choose the first time ranking 1 appeared.
                # If all scores are `nan`, `rankdata` will also produce an array of `nan` values.
                # In that case, default to first index.
                best_param_index = (
                    np.where(cv_results_[f"rank_test_{score}"] == 1)[0][0]
                    if not np.isnan(cv_results_[f"rank_test_{score}"]).all()
                    else 0
                )
            estimator.cv_results_ = cv_results_
            estimator.multimetric_ = multimetric

            # Reconstruct the sklearn estimator.
            refit_metric = "score"
            if callable(estimator.scoring):
                scorers = estimator.scoring
            elif estimator.scoring is None or isinstance(estimator.scoring, str):
                scorers = check_scoring(estimator.estimator, estimator.scoring)
            else:
                scorers = _check_multimetric_scoring(estimator.estimator, estimator.scoring)
                estimator._check_refit_for_multimetric(scorers)
                refit_metric = original_refit

            estimator.scorer_ = scorers

            # check refit_metric now for a callable scorer that is multimetric
            if callable(estimator.scoring) and estimator.multimetric_:
                refit_metric = original_refit

            # For multi-metric evaluation, store the best_index_, best_params_ and
            # best_score_ iff refit is one of the scorer names
            # In single metric evaluation, refit_metric is "score"
            if original_refit or not estimator.multimetric_:
                estimator.best_index_ = estimator._select_best_index(original_refit, refit_metric, cv_results_)
                if not callable(original_refit):
                    # With a non-custom callable, we can select the best score
                    # based on the best index
                    estimator.best_score_ = cv_results_[f"mean_test_{refit_metric}"][estimator.best_index_]
                estimator.best_params_ = cv_results_["params"][best_param_index]

            if original_refit:
                estimator.best_estimator_ = clone(estimator.estimator).set_params(
                    **clone(estimator.best_params_, safe=False)
                )

                # Let the sproc use all cores to refit.
                estimator.n_jobs = -1 if not estimator.n_jobs else estimator.n_jobs

                # process the input as args
                argspec = inspect.getfullargspec(estimator.fit)
                args = {"X": X}
                if label_cols:
                    label_arg_name = "Y" if "Y" in argspec.args else "y"
                    args[label_arg_name] = y
                if sample_weight_col is not None and "sample_weight" in argspec.args:
                    args["sample_weight"] = df[sample_weight_col].squeeze()
                estimator.refit = original_refit
                refit_start_time = time.time()
                estimator.best_estimator_.fit(**args)
                refit_end_time = time.time()
                estimator.refit_time_ = refit_end_time - refit_start_time

            if hasattr(estimator.best_estimator_, "feature_names_in_"):
                estimator.feature_names_in_ = estimator.best_estimator_.feature_names_in_

            local_result_file_name = get_temp_file_path()

            with open(local_result_file_name, mode="w+b") as local_result_file_obj:
                cp.dump(estimator, local_result_file_obj)

            session.file.put(
                local_result_file_name,
                temp_stage_name,
                auto_compress=False,
                overwrite=True,
            )

            # Note: you can add something like + "|" + str(df) to the return string
            # to pass debug information to the caller.
            return str(os.path.basename(local_result_file_name))

        sproc_export_file_name = _distributed_search(
            session,
            imports,
            stage_estimator_file_name,
            input_cols,
            label_cols,
        )

        local_estimator_path = get_temp_file_path()
        session.file.get(
            posixpath.join(temp_stage_name, sproc_export_file_name),
            local_estimator_path,
        )

        with open(os.path.join(local_estimator_path, sproc_export_file_name), mode="r+b") as result_file_obj:
            fit_estimator = cp.load(result_file_obj)

        cleanup_temp_files([local_estimator_path])

        return fit_estimator

    def train(self) -> object:
        """
        Runs hyperparameter optimization by distributing the tasks across the warehouse.

        Returns:
            Trained model
        """
        model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
        assert isinstance(self.estimator, model_selection.GridSearchCV) or isinstance(
            self.estimator, model_selection.RandomizedSearchCV
        )
        if hasattr(self.estimator.estimator, "n_jobs") and self.estimator.estimator.n_jobs in [
            None,
            -1,
        ]:
            self.estimator.estimator.n_jobs = DEFAULT_UDTF_NJOBS

        if isinstance(self.estimator, model_selection.GridSearchCV):
            param_grid = model_selection.ParameterGrid(self.estimator.param_grid)
        elif isinstance(self.estimator, model_selection.RandomizedSearchCV):
            param_grid = model_selection.ParameterSampler(
                self.estimator.param_distributions,
                n_iter=self.estimator.n_iter,
                random_state=self.estimator.random_state,
            )
        return self.fit_search_snowpark(
            param_grid=param_grid,
            dataset=self.dataset,
            session=self.session,
            estimator=self.estimator,
            dependencies=model_spec.pkgDependencies,
            udf_imports=["sklearn"],
            input_cols=self.input_cols,
            label_cols=self.label_cols,
            sample_weight_col=self.sample_weight_col,
        )
```
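End to end, `fit_search_snowpark` stages the training data and the pickled searcher, fans the (parameter combination × CV fold) grid out over a UDTF partitioned by `PARAM_INDEX`, reassembles a global `cv_results_` from the per-fold results, and refits the best estimator inside the stored procedure. A minimal driver sketch, assuming a live Snowflake connection — the connection parameters, table, and column names below are placeholders, and this internal class would normally be reached through the `snowflake.ml.modeling.model_selection` wrappers rather than instantiated by hand:

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from snowflake.snowpark import Session

from snowflake.ml.modeling._internal.distributed_hpo_trainer import DistributedHPOTrainer

# Placeholder credentials; fill in a real account, user, and warehouse.
connection_parameters = {"account": "...", "user": "...", "password": "...", "warehouse": "..."}
session = Session.builder.configs(connection_parameters).create()

train_df = session.table("MY_DB.MY_SCHEMA.TRAINING_DATA")  # placeholder table

search = GridSearchCV(
    estimator=RandomForestClassifier(),
    param_grid={"n_estimators": [50, 100], "max_depth": [3, 5]},
)

trainer = DistributedHPOTrainer(
    estimator=search,
    dataset=train_df,
    session=session,
    input_cols=["FEATURE_1", "FEATURE_2"],  # placeholder column names
    label_cols=["LABEL"],
    sample_weight_col=None,
)

# train() caps the sub-estimator at DEFAULT_UDTF_NJOBS (3) when n_jobs is None or -1,
# then runs one UDTF partition per (parameter combination, CV fold) pair.
fitted_search = trainer.train()
print(fitted_search.best_params_, fitted_search.best_score_)
```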
snowflake/ml/modeling/_internal/estimator_protocols.py
```diff
@@ -1,35 +1,12 @@
-from typing import List, Optional, Protocol, Union
+from typing import List, Optional, Protocol
 
 import pandas as pd
-from sklearn import model_selection
 
 from snowflake.snowpark import DataFrame, Session
 
 
 # TODO: Add more specific entities to type hint estimators instead of using `object`.
 class FitPredictHandlers(Protocol):
-    def fit_snowpark(
-        self,
-        dataset: DataFrame,
-        session: Session,
-        estimator: object,
-        dependencies: List[str],
-        input_cols: List[str],
-        label_cols: List[str],
-        sample_weight_col: Optional[str],
-    ) -> object:
-        raise NotImplementedError
-
-    def fit_pandas(
-        self,
-        dataset: pd.DataFrame,
-        estimator: object,
-        input_cols: List[str],
-        label_cols: Optional[List[str]],
-        sample_weight_col: Optional[str],
-    ) -> object:
-        raise NotImplementedError
-
     def batch_inference(
         self,
         dataset: DataFrame,
@@ -70,28 +47,6 @@ class FitPredictHandlers(Protocol):
 
 # TODO: Add more specific entities to type hint estimators instead of using `object`.
 class CVHandlers(Protocol):
-    def fit_snowpark(
-        self,
-        dataset: DataFrame,
-        session: Session,
-        estimator: object,
-        dependencies: List[str],
-        input_cols: List[str],
-        label_cols: List[str],
-        sample_weight_col: Optional[str],
-    ) -> object:
-        raise NotImplementedError
-
-    def fit_pandas(
-        self,
-        dataset: pd.DataFrame,
-        estimator: object,
-        input_cols: List[str],
-        label_cols: Optional[List[str]],
-        sample_weight_col: Optional[str],
-    ) -> object:
-        raise NotImplementedError
-
     def batch_inference(
         self,
         dataset: DataFrame,
@@ -128,17 +83,3 @@ class CVHandlers(Protocol):
         sample_weight_col: Optional[str],
     ) -> float:
         raise NotImplementedError
-
-    def fit_search_snowpark(
-        self,
-        param_grid: Union[model_selection.ParameterGrid, model_selection.ParameterSampler],
-        dataset: DataFrame,
-        session: Session,
-        estimator: Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV],
-        dependencies: List[str],
-        udf_imports: List[str],
-        input_cols: List[str],
-        label_cols: List[str],
-        sample_weight_col: Optional[str],
-    ) -> Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV]:
-        raise NotImplementedError
```
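The slimmed protocols drop `fit_snowpark`, `fit_pandas`, and `fit_search_snowpark` — fitting is now the responsibility of the new trainer classes — leaving `batch_inference` and the remaining inference/scoring methods. Because `typing.Protocol` matches classes structurally, any handler that still provides the remaining methods keeps satisfying the narrowed `FitPredictHandlers` and `CVHandlers` without declaring anything. A minimal sketch of that structural check, with `Greeter`/`ConsoleGreeter` as illustrative stand-ins:

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class Greeter(Protocol):
    def greet(self, name: str) -> str:
        ...


class ConsoleGreeter:
    # No inheritance from Greeter: conformance is purely structural.
    def greet(self, name: str) -> str:
        return f"hello, {name}"

    def farewell(self, name: str) -> str:  # extra methods do not break conformance
        return f"bye, {name}"


assert isinstance(ConsoleGreeter(), Greeter)  # passes: the required method is present
```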