snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as published to their public registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in those registries.
- snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
- snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
- snowflake/ml/_internal/env_utils.py +31 -52
- snowflake/ml/_internal/file_utils.py +17 -0
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/_internal/utils/query_result_checker.py +8 -5
- snowflake/ml/_internal/utils/snowflake_env.py +95 -0
- snowflake/ml/fileset/parquet_parser.py +31 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/model_impl.py +172 -13
- snowflake/ml/model/_client/model/model_version_impl.py +96 -52
- snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
- snowflake/ml/model/_client/ops/model_ops.py +155 -9
- snowflake/ml/model/_client/sql/model.py +55 -10
- snowflake/ml/model/_client/sql/model_version.py +72 -61
- snowflake/ml/model/_client/sql/stage.py +10 -4
- snowflake/ml/model/_client/sql/tag.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/_signatures/core.py +20 -17
- snowflake/ml/model/custom_model.py +30 -27
- snowflake/ml/model/model_signature.py +16 -17
- snowflake/ml/model/type_hints.py +3 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
- snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
- snowflake/ml/modeling/_internal/model_specifications.py +3 -10
- snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
- snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
- snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
- snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
- snowflake/ml/modeling/cluster/birch.py +51 -16
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
- snowflake/ml/modeling/cluster/dbscan.py +51 -16
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
- snowflake/ml/modeling/cluster/k_means.py +51 -16
- snowflake/ml/modeling/cluster/mean_shift.py +51 -16
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
- snowflake/ml/modeling/cluster/optics.py +51 -16
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
- snowflake/ml/modeling/compose/column_transformer.py +51 -16
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
- snowflake/ml/modeling/covariance/oas.py +51 -16
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/pca.py +51 -16
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
- snowflake/ml/modeling/impute/knn_imputer.py +51 -16
- snowflake/ml/modeling/impute/missing_indicator.py +51 -16
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/lars.py +51 -16
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/perceptron.py +51 -16
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ridge.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
- snowflake/ml/modeling/manifold/isomap.py +51 -16
- snowflake/ml/modeling/manifold/mds.py +51 -16
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
- snowflake/ml/modeling/manifold/tsne.py +51 -16
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
- snowflake/ml/modeling/svm/linear_svc.py +51 -16
- snowflake/ml/modeling/svm/linear_svr.py +51 -16
- snowflake/ml/modeling/svm/nu_svc.py +51 -16
- snowflake/ml/modeling/svm/nu_svr.py +51 -16
- snowflake/ml/modeling/svm/svc.py +51 -16
- snowflake/ml/modeling/svm/svr.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
- snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
- snowflake/ml/registry/__init__.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +163 -0
- snowflake/ml/registry/model_registry.py +12 -0
- snowflake/ml/registry/registry.py +100 -90
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
- snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
- snowflake/ml/model/_client/model/model_method_info.py +0 -19
- snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
- /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
- /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
snowflake/ml/modeling/_internal/distributed_hpo_trainer.py:

@@ -8,11 +8,15 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 
 import cloudpickle as cp
 import numpy as np
-from scipy.stats import rankdata
 from sklearn import model_selection
+from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
 
 from snowflake.ml._internal import telemetry
-from snowflake.ml._internal.utils import
+from snowflake.ml._internal.utils import (
+    identifier,
+    pkg_version_utils,
+    snowpark_dataframe_utils,
+)
 from snowflake.ml._internal.utils.temp_file_utils import (
     cleanup_temp_files,
     get_temp_file_path,
@@ -26,7 +30,8 @@ from snowflake.snowpark._internal.utils import (
     TempObjectType,
     random_name_for_temp_object,
 )
-from snowflake.snowpark.functions import
+from snowflake.snowpark.functions import sproc, udtf
+from snowflake.snowpark.row import Row
 from snowflake.snowpark.types import IntegerType, StringType, StructField, StructType
 
 cp.register_pickle_by_value(inspect.getmodule(get_temp_file_path))
@@ -36,6 +41,117 @@ _PROJECT = "ModelDevelopment"
 DEFAULT_UDTF_NJOBS = 3
 
 
+def construct_cv_results(
+    estimator: Union[GridSearchCV, RandomizedSearchCV],
+    n_split: int,
+    param_grid: List[Dict[str, Any]],
+    cv_results_raw_hex: List[Row],
+    cross_validator_indices_length: int,
+    parameter_grid_length: int,
+) -> Tuple[bool, Dict[str, Any]]:
+    """Construct the cross validation result from the UDF. Because we accelerate the process
+    by the number of cross validation number, and the combination of parameter grids.
+    Therefore, we need to stick them back together instead of returning the raw result
+    to align with original sklearn result.
+
+    Args:
+        estimator (Union[GridSearchCV, RandomizedSearchCV]): The sklearn object of estimator
+            GridSearchCV or RandomizedSearchCV
+        n_split (int): The number of split, which is determined by build_cross_validator.get_n_splits(X, y, groups)
+        param_grid (List[Dict[str, Any]]): the list of parameter grid or parameter sampler
+        cv_results_raw_hex (List[Row]): the list of cv_results from each cv and parameter grid combination.
+            Because UDxF can only return string, and numpy array/masked arrays cannot be encoded in a
+            json format. Each cv_result is encoded into hex string.
+        cross_validator_indices_length (int): the length of cross validator indices
+        parameter_grid_length (int): the length of parameter grid combination
+
+    Raises:
+        ValueError: Retrieved empty cross validation results
+        ValueError: Cross validator index length is 0
+        ValueError: Parameter index length is 0
+        ValueError: Retrieved incorrect dataframe dimension from Snowpark's UDTF.
+        RuntimeError: Cross validation results are unexpectedly empty for one fold.
+
+    Returns:
+        Tuple[bool, Dict[str, Any]]: returns multimetric, cv_results_
+    """
+    # Filter corner cases: either the snowpark dataframe result is empty; or index length is empty
+    if len(cv_results_raw_hex) == 0:
+        raise ValueError(
+            "Retrieved empty cross validation results from snowpark. Please retry or contact snowflake support."
+        )
+    if cross_validator_indices_length == 0:
+        raise ValueError("Cross validator index length is 0. Was the CV iterator empty? ")
+    if parameter_grid_length == 0:
+        raise ValueError("Parameter index length is 0. Were there no candidates?")
+
+    # cv_result maintains the original order
+    multimetric = False
+    # retrieve the cv_results from udtf table; results are encoded by hex and cloudpickle;
+    # We are constructing the raw information back to original form
+    if len(cv_results_raw_hex) != cross_validator_indices_length * parameter_grid_length:
+        raise ValueError(
+            "Retrieved incorrect dataframe dimension from Snowpark's UDTF."
+            f"Expected {cross_validator_indices_length * parameter_grid_length}, got {len(cv_results_raw_hex)}. "
+            "Please retry or contact snowflake support."
+        )
+
+    out = []
+
+    for each_cv_result_hex in cv_results_raw_hex:
+        # convert the hex string back to cv_results_
+        hex_str = bytes.fromhex(each_cv_result_hex[0])
+        with io.BytesIO(hex_str) as f_reload:
+            each_cv_result = cp.load(f_reload)
+            if not each_cv_result:
+                raise RuntimeError(
+                    "Cross validation response is empty. This issue may be temporary - please try again."
+                )
+            temp_dict = dict()
+            """
+            This dictionary has the following keys
+            train_scores : dict of scorer name -> float
+                Score on training set (for all the scorers),
+                returned only if `return_train_score` is `True`.
+            test_scores : dict of scorer name -> float
+                Score on testing set (for all the scorers).
+            fit_time : float
+                Time spent for fitting in seconds.
+            score_time : float
+                Time spent for scoring in seconds.
+            """
+            if estimator.return_train_score:
+                if each_cv_result.get("split0_train_score", None):
+                    # for single scorer, the split0_train_score only contains an array with one value
+                    temp_dict["train_scores"] = each_cv_result["split0_train_score"][0]
+                else:
+                    # if multimetric situation, the format would be
+                    # {metric_name1: value, metric_name2: value, ...}
+                    temp_dict["train_scores"] = {}
+                    # For multi-metric evaluation, the scores for all the scorers are available in the
+                    # cv_results_ dict at the keys ending with that scorer's name ('_<scorer_name>')
+                    # instead of '_score'.
+                    for k, v in each_cv_result.items():
+                        if "split0_train_" in k:
+                            temp_dict["train_scores"][k[len("split0_train_") :]] = v
+            if isinstance(each_cv_result.get("split0_test_score"), np.ndarray):
+                temp_dict["test_scores"] = each_cv_result["split0_test_score"][0]
+            else:
+                temp_dict["test_scores"] = {}
+                for k, v in each_cv_result.items():
+                    if "split0_test_" in k:
+                        temp_dict["test_scores"][k[len("split0_test_") :]] = v
+            temp_dict["fit_time"] = each_cv_result["mean_fit_time"][0]
+            temp_dict["score_time"] = each_cv_result["mean_score_time"][0]
+            out.append(temp_dict)
+    first_test_score = out[0]["test_scores"]
+    multimetric = isinstance(first_test_score, dict)
+    return multimetric, estimator._format_results(param_grid, n_split, out)
+
+
+cp.register_pickle_by_value(inspect.getmodule(construct_cv_results))
+
+
 class DistributedHPOTrainer(SnowparkModelTrainer):
     """
     A class for performing distributed hyperparameter optimization (HPO) using Snowpark.
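A UDTF can only emit strings, and `cv_results_` holds numpy arrays that JSON cannot carry, so each fold's results travel as a hex-encoded cloudpickle payload that `construct_cv_results` decodes. A minimal standalone sketch of that round trip:

```python
import io

import cloudpickle as cp
import numpy as np

# A stand-in for one fold's cv_results_ (real payloads hold numpy arrays).
cv_results = {"mean_fit_time": np.array([0.12]), "split0_test_score": np.array([0.9])}

# Encode: cloudpickle to bytes, then to a hex string the UDTF can emit.
with io.BytesIO() as f:
    cp.dump(cv_results, f)
    payload_hex = f.getvalue().hex()

# Decode: the same steps construct_cv_results performs per returned row.
with io.BytesIO(bytes.fromhex(payload_hex)) as f_reload:
    restored = cp.load(f_reload)

assert np.array_equal(restored["split0_test_score"], cv_results["split0_test_score"])
```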
snowflake/ml/modeling/_internal/distributed_hpo_trainer.py (continued):

@@ -105,7 +221,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         temp_stage_creation_query = f"CREATE OR REPLACE TEMP STAGE {temp_stage_name};"
         session.sql(temp_stage_creation_query).collect()
 
-        # Stage data
+        # Stage data as parquet file
         dataset = snowpark_dataframe_utils.cast_snowpark_dataframe(dataset)
         remote_file_path = f"{temp_stage_name}/{temp_stage_name}.parquet"
         dataset.write.copy_into_location(  # type:ignore[call-overload]
@@ -114,6 +230,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         imports = [f"@{row.name}" for row in session.sql(f"LIST @{temp_stage_name}").collect()]
 
         # Store GridSearchCV's refit variable. If user set it as False, we don't need to refit it again
+        # refit variable can be boolean, string or callable
         original_refit = estimator.refit
 
         # Create a temp file and dump the estimator to that file.
@@ -136,7 +253,6 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
                 inspect.currentframe(), self.__class__.__name__
             ),
             api_calls=[sproc],
-            custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
         )
         udtf_statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
@@ -145,7 +261,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
                 inspect.currentframe(), self.__class__.__name__
             ),
             api_calls=[udtf],
-            custom_tags=dict([("
+            custom_tags=dict([("hpo_udtf", True)]),
         )
 
         # Put locally serialized estimator on stage.
@@ -208,7 +324,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
                 for file_name in data_files
             ]
             df = pd.concat(partial_df, ignore_index=True)
-            df.columns = [identifier.get_inferred_name(
+            df.columns = [identifier.get_inferred_name(col_) for col_ in df.columns]
 
             X = df[input_cols]
             y = df[label_cols].squeeze() if label_cols else None
@@ -222,11 +338,16 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
             with open(local_estimator_file_path, mode="r+b") as local_estimator_file_obj:
                 estimator = cp.load(local_estimator_file_obj)["estimator"]
 
-
-
+            build_cross_validator = check_cv(estimator.cv, y, classifier=is_classifier(estimator.estimator))
+            from sklearn.utils.validation import indexable
+
+            X, y, _ = indexable(X, y, None)
+            n_splits = build_cross_validator.get_n_splits(X, y, None)
+            # store the cross_validator's test indices only to save space
+            cross_validator_indices = [test for _, test in build_cross_validator.split(X, y, None)]
             local_indices_file_name = get_temp_file_path()
             with open(local_indices_file_name, mode="w+b") as local_indices_file_obj:
-                cp.dump(
+                cp.dump(cross_validator_indices, local_indices_file_obj)
 
             # Put locally serialized indices on stage.
             put_result = session.file.put(
@@ -237,7 +358,8 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
             )
             indices_location = put_result[0].target
             imports.append(f"@{temp_stage_name}/{indices_location}")
-
+            cross_validator_indices_length = int(len(cross_validator_indices))
+            parameter_grid_length = len(param_grid)
 
             assert estimator is not None
 
@@ -261,7 +383,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
                 for file_name in data_files
             ]
             df = pd.concat(partial_df, ignore_index=True)
-            df.columns = [identifier.get_inferred_name(
+            df.columns = [identifier.get_inferred_name(col_) for col_ in df.columns]
 
             # load estimator
             local_estimator_file_path = os.path.join(
@@ -299,16 +421,30 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
                 self.data_length = data_length
                 self.params_to_evaluate = params_to_evaluate
 
-            def process(self, params_idx: int,
+            def process(self, params_idx: int, cv_idx: int) -> Iterator[Tuple[str]]:
+                # Assign parameter to GridSearchCV
                 if hasattr(estimator, "param_grid"):
                     self.estimator.param_grid = self.params_to_evaluate[params_idx]
+                # Assign parameter to RandomizedSearchCV
                 else:
                     self.estimator.param_distributions = self.params_to_evaluate[params_idx]
+                # cross validator's indices: we stored test indices only (to save space);
+                # use the full indices to re-construct the train indices back.
                 full_indices = np.array([i for i in range(self.data_length)])
-                test_indice = self.indices[
+                test_indice = self.indices[cv_idx]
                 train_indice = np.setdiff1d(full_indices, test_indice)
+                # assign the tuple of train and test indices to estimator's original cross validator
                 self.estimator.cv = [(train_indice, test_indice)]
                 self.estimator.fit(**self.args)
+                # If the cv_results_ is empty, then the udtf table will have different number of output rows
+                # from the input rows. Raise ValueError.
+                if not self.estimator.cv_results_:
+                    raise RuntimeError(
+                        """Cross validation results are unexpectedly empty for one fold.
+                        This issue may be temporary - please try again."""
+                    )
+                # Encode the dictionary of cv_results_ as binary (in hex format) to send it back
+                # because udtf doesn't allow numpy within json file
                 binary_cv_results = None
                 with io.BytesIO() as f:
                     cp.dump(self.estimator.cv_results_, f)
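Only each fold's test indices are staged (per the "store the cross_validator's test indices only to save space" comment above); inside the UDTF the train indices are rebuilt as the complement over the full row range. A standalone sketch:

```python
import numpy as np

data_length = 10                   # rows in the staged dataset
test_indice = np.array([2, 5, 7])  # one fold's stored test indices

# Rebuild the train indices as the complement of the test indices.
full_indices = np.arange(data_length)
train_indice = np.setdiff1d(full_indices, test_indice)

print(train_indice)  # [0 1 3 4 6 8 9]
```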
snowflake/ml/modeling/_internal/distributed_hpo_trainer.py (continued):

@@ -333,96 +469,44 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
 
         HP_TUNING = F.table_function(random_udtf_name)
 
-
-
-
-        param_indices,
-        for param_idx, cv_idx in product(
+        # param_indices is for the index for each parameter grid;
+        # cv_indices is for the index for each cross_validator's fold;
+        # param_cv_indices is for the index for the product of (len(param_indices) * len(cv_indices))
+        param_indices, cv_indices = [], []
+        for param_idx, cv_idx in product(
+            [param_index for param_index in range(parameter_grid_length)],
+            [cv_index for cv_index in range(cross_validator_indices_length)],
+        ):
             param_indices.append(param_idx)
-
+            cv_indices.append(cv_idx)
 
-
+        indices_info_pandas = pd.DataFrame(
             {
-                "
-                "
-                "
+                "PARAM_IND": param_indices,
+                "CV_IND": cv_indices,
+                "PARAM_CV_IND": [i for i in range(cross_validator_indices_length * parameter_grid_length)],
             }
         )
-
-
-
-        (
+        indices_info_sp = session.create_dataframe(indices_info_pandas)
+        # execute udtf by querying HP_TUNING table
+        HP_raw_results = indices_info_sp.select(
+            F.cast(indices_info_sp["PARAM_CV_IND"], IntegerType()).as_("PARAM_CV_IND"),
+            (
+                HP_TUNING(indices_info_sp["PARAM_IND"], indices_info_sp["CV_IND"]).over(
+                    partition_by=indices_info_sp["PARAM_CV_IND"]
+                )
+            ),
         )
-
-
-
-
-
-
-
-
-            hex_str = bytes.fromhex(val[0])
-            with io.BytesIO(hex_str) as f_reload:
-                each_cv_result = cp.load(f_reload)
-                for k, v in each_cv_result.items():
-                    cur_cv = i % idx_length
-                    key = k
-                    if "split0_test_" in k:
-                        # For multi-metric evaluation, the scores for all the scorers are available in the
-                        # cv_results_ dict at the keys ending with that scorer's name ('_<scorer_name>')
-                        # instead of '_score'.
-                        scorers.add(k[len("split0_test_") :])
-                        key = k.replace("split0_test", f"split{cur_cv}_test")
-                    elif k.startswith("param"):
-                        if cur_cv != 0:
-                            key = False
-                    if key:
-                        if key not in cv_results_:
-                            cv_results_[key] = v
-                        else:
-                            cv_results_[key] = np.concatenate([cv_results_[key], v])
-
-        multimetric = len(scorers) > 1
-        # Use numpy to re-calculate all the information in cv_results_ again
-        # Generally speaking, reshape all the results into the (scorers+2, idx_length, params_length) shape,
-        # and average them by the idx_length;
-        # idx_length is the number of cv folds; params_length is the number of parameter combinations
-        scores = [
-            np.reshape(
-                np.concatenate([cv_results_[f"split{cur_cv}_test_{score}"] for cur_cv in range(idx_length)]),
-                (idx_length, -1),
-            )
-            for score in scorers
-        ]
-
-        fit_score_test_matrix = np.stack(
-            [
-                np.reshape(cv_results_["mean_fit_time"], (idx_length, -1)),
-                np.reshape(cv_results_["mean_score_time"], (idx_length, -1)),
-            ]
-            + scores
+        # multimetric, cv_results_, best_param_index, scorers
+        multimetric, cv_results_ = construct_cv_results(
+            estimator,
+            n_splits,
+            list(param_grid),
+            HP_raw_results.select("CV_RESULTS").sort(F.col("PARAM_CV_IND")).collect(),
+            cross_validator_indices_length,
+            parameter_grid_length,
         )
 
-        mean_fit_score_test_matrix = np.mean(fit_score_test_matrix, axis=1)
-        std_fit_score_test_matrix = np.std(fit_score_test_matrix, axis=1)
-        cv_results_["std_fit_time"] = std_fit_score_test_matrix[0]
-        cv_results_["mean_fit_time"] = mean_fit_score_test_matrix[0]
-        cv_results_["std_score_time"] = std_fit_score_test_matrix[1]
-        cv_results_["mean_score_time"] = mean_fit_score_test_matrix[1]
-        for idx, score in enumerate(scorers):
-            cv_results_[f"std_test_{score}"] = std_fit_score_test_matrix[idx + 2]
-            cv_results_[f"mean_test_{score}"] = mean_fit_score_test_matrix[idx + 2]
-            # re-compute the ranking again with mean_test_<score>.
-            cv_results_[f"rank_test_{score}"] = rankdata(-cv_results_[f"mean_test_{score}"], method="min")
-        # The best param is the highest ranking (which is 1) and we choose the first time ranking 1 appeared.
-        # If all scores are `nan`, `rankdata` will also produce an array of `nan` values.
-        # In that case, default to first index.
-        best_param_index = (
-            np.where(cv_results_[f"rank_test_{score}"] == 1)[0][0]
-            if not np.isnan(cv_results_[f"rank_test_{score}"]).all()
-            else 0
-        )
-
         estimator.cv_results_ = cv_results_
         estimator.multimetric_ = multimetric
 
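Each (parameter combination, fold) pair becomes one row of the driver dataframe, and partitioning the UDTF call by `PARAM_CV_IND` gives every partition exactly one fit. The index bookkeeping alone, as a standalone sketch:

```python
from itertools import product

import pandas as pd

parameter_grid_length = 3           # parameter combinations
cross_validator_indices_length = 5  # CV folds

# One row per (parameter combination, fold); PARAM_CV_IND enumerates the product.
param_indices, cv_indices = zip(
    *product(range(parameter_grid_length), range(cross_validator_indices_length))
)
indices_info = pd.DataFrame(
    {
        "PARAM_IND": param_indices,
        "CV_IND": cv_indices,
        "PARAM_CV_IND": range(parameter_grid_length * cross_validator_indices_length),
    }
)
print(len(indices_info))  # 15 rows -> 15 UDTF partitions, one fit each
```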
snowflake/ml/modeling/_internal/distributed_hpo_trainer.py (continued):

@@ -452,7 +536,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         # With a non-custom callable, we can select the best score
         # based on the best index
         estimator.best_score_ = cv_results_[f"mean_test_{refit_metric}"][estimator.best_index_]
-        estimator.best_params_ = cv_results_["params"][
+        estimator.best_params_ = cv_results_["params"][estimator.best_index_]
 
         if original_refit:
             estimator.best_estimator_ = clone(estimator.estimator).set_params(
@@ -541,12 +625,15 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
                 n_iter=self.estimator.n_iter,
                 random_state=self.estimator.random_state,
             )
+        relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
+            pkg_versions=model_spec.pkgDependencies, session=self.session
+        )
         return self.fit_search_snowpark(
             param_grid=param_grid,
             dataset=self.dataset,
             session=self.session,
             estimator=self.estimator,
-            dependencies=
+            dependencies=relaxed_dependencies,
             udf_imports=["sklearn"],
             input_cols=self.input_cols,
             label_cols=self.label_cols,
snowflake/ml/modeling/_internal/estimator_utils.py:

@@ -132,3 +132,24 @@ def is_single_node(session: Session) -> bool:
     # If current session cannot retrieve the warehouse name back,
     # Default as True; Let HPO fall back to stored procedure implementation
     return True
+
+
+def get_module_name(model: object) -> str:
+    """Returns the source module of the given object.
+
+    Args:
+        model: Object to inspect.
+
+    Returns:
+        Source module of the given object.
+
+    Raises:
+        SnowflakeMLException: If the source module of the given object is not found.
+    """
+    module = inspect.getmodule(model)
+    if module is None:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.INVALID_TYPE,
+            original_exception=ValueError(f"Unable to infer the source module of the given object {model}."),
+        )
+    return module.__name__
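A short usage sketch of the new helper (assuming scikit-learn is installed); the dispatch on the root package mirrors how `model_specifications.py` and `model_trainer_builder.py` consume it below:

```python
from sklearn.linear_model import LinearRegression

from snowflake.ml.modeling._internal.estimator_utils import get_module_name

model = LinearRegression()

# The dotted module path's first component identifies the library the
# estimator came from; callers dispatch on it ("sklearn", "xgboost", ...).
module_name = get_module_name(model=model)    # e.g. "sklearn.linear_model._base"
root_module_name = module_name.split(".")[0]  # "sklearn"
assert root_module_name == "sklearn"
```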
snowflake/ml/modeling/_internal/model_specifications.py:

@@ -1,10 +1,9 @@
-import inspect
 from typing import List
 
 import cloudpickle as cp
 import numpy as np
 
-from snowflake.ml._internal.
+from snowflake.ml.modeling._internal.estimator_utils import get_module_name
 
 
 class ModelSpecifications:
@@ -120,16 +119,10 @@ class ModelSpecificationsBuilder:
             Appropriate ModelSpecification object
 
         Raises:
-            SnowflakeMLException: Raises an exception the module of given model can't be determined.
             TypeError: Raises the exception for unsupported modules.
         """
-
-
-        raise exceptions.SnowflakeMLException(
-            error_code=error_codes.INVALID_TYPE,
-            original_exception=ValueError("Unable to infer model type of the given native model object."),
-        )
-        root_module_name = module.__name__.split(".")[0]
+        module_name = get_module_name(model=model)
+        root_module_name = module_name.split(".")[0]
         if root_module_name == "sklearn":
             from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
 
snowflake/ml/modeling/_internal/model_trainer_builder.py:

@@ -3,13 +3,20 @@ from typing import List, Optional, Union
 import pandas as pd
 from sklearn import model_selection
 
+from snowflake.ml._internal.exceptions import error_codes, exceptions
 from snowflake.ml.modeling._internal.distributed_hpo_trainer import (
     DistributedHPOTrainer,
 )
-from snowflake.ml.modeling._internal.estimator_utils import
+from snowflake.ml.modeling._internal.estimator_utils import (
+    get_module_name,
+    is_single_node,
+)
 from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
 from snowflake.ml.modeling._internal.pandas_trainer import PandasModelTrainer
 from snowflake.ml.modeling._internal.snowpark_trainer import SnowparkModelTrainer
+from snowflake.ml.modeling._internal.xgboost_external_memory_trainer import (
+    XGBoostExternalMemoryTrainer,
+)
 from snowflake.snowpark import DataFrame, Session
 
 _PROJECT = "ModelDevelopment"
@@ -30,6 +37,31 @@ class ModelTrainerBuilder:
     def _check_if_distributed_hpo_enabled(cls, session: Session) -> bool:
         return not is_single_node(session) and ModelTrainerBuilder._ENABLE_DISTRIBUTED is True
 
+    @classmethod
+    def _validate_external_memory_params(cls, estimator: object, batch_size: int) -> None:
+        """
+        Validate the params are set appropriately for external memory training.
+
+        Args:
+            estimator: Model object
+            batch_size: Number of rows in each batch of data processed during training.
+
+        Raises:
+            SnowflakeMLException: If the params are not appropriate for the external memory training feature.
+        """
+        module_name = get_module_name(model=estimator)
+        root_module_name = module_name.split(".")[0]
+        if root_module_name != "xgboost":
+            raise exceptions.SnowflakeMLException(
+                error_code=error_codes.INVALID_ARGUMENT,
+                original_exception=RuntimeError("External memory training is only supported for XGBoost models."),
+            )
+        if batch_size <= 0:
+            raise exceptions.SnowflakeMLException(
+                error_code=error_codes.INVALID_ARGUMENT,
+                original_exception=RuntimeError("Batch size must be >= 0 when using external memory training feature."),
+            )
+
     @classmethod
     def build(
         cls,
@@ -40,6 +72,8 @@ class ModelTrainerBuilder:
         sample_weight_col: Optional[str] = None,
         autogenerated: bool = False,
         subproject: str = "",
+        use_external_memory_version: bool = False,
+        batch_size: int = -1,
     ) -> ModelTrainer:
         """
         Builder method that creates an approproiate ModelTrainer instance based on the given params.
@@ -55,22 +89,32 @@ class ModelTrainerBuilder:
             )
         elif isinstance(dataset, DataFrame):
             trainer_klass = SnowparkModelTrainer
+            init_args = {
+                "estimator": estimator,
+                "dataset": dataset,
+                "session": dataset._session,
+                "input_cols": input_cols,
+                "label_cols": label_cols,
+                "sample_weight_col": sample_weight_col,
+                "autogenerated": autogenerated,
+                "subproject": subproject,
+            }
+
             assert dataset._session is not None  # Make MyPy happpy
             if isinstance(estimator, model_selection.GridSearchCV) or isinstance(
                 estimator, model_selection.RandomizedSearchCV
             ):
                 if ModelTrainerBuilder._check_if_distributed_hpo_enabled(session=dataset._session):
                     trainer_klass = DistributedHPOTrainer
-
-
-
-
-
-
-
-
-
-            )
+            elif use_external_memory_version:
+                ModelTrainerBuilder._validate_external_memory_params(
+                    estimator=estimator,
+                    batch_size=batch_size,
+                )
+                trainer_klass = XGBoostExternalMemoryTrainer
+                init_args["batch_size"] = batch_size
+
+            return trainer_klass(**init_args)  # type: ignore[arg-type]
         else:
             raise TypeError(
                 f"Unexpected dataset type: {type(dataset)}."
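A hypothetical call into the extended builder; the leading parameter names match the `init_args` keys in the hunk above, and `snowpark_df` stands in for a Snowpark DataFrame obtained elsewhere:

```python
from snowflake.snowpark import DataFrame
from xgboost import XGBClassifier

from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder


def build_external_memory_trainer(snowpark_df: DataFrame) -> object:
    # use_external_memory_version routes to XGBoostExternalMemoryTrainer after
    # _validate_external_memory_params checks the estimator type and batch size.
    return ModelTrainerBuilder.build(
        estimator=XGBClassifier(),
        dataset=snowpark_df,
        input_cols=["FEATURE_1", "FEATURE_2"],
        label_cols=["LABEL"],
        use_external_memory_version=True,  # only XGBoost estimators pass validation
        batch_size=10_000,                 # must be positive
    )
```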
snowflake/ml/modeling/_internal/snowpark_handlers.py:

@@ -306,7 +306,7 @@ class SnowparkHandlers:
             input_cols: List[str],
             label_cols: List[str],
             sample_weight_col: Optional[str],
-
+            score_statement_params: Dict[str, str],
         ) -> float:
             import inspect
             import os
@@ -317,13 +317,13 @@ class SnowparkHandlers:
                 importlib.import_module(import_name)
 
             for query in sql_queries[:-1]:
-                _ = session.sql(query).collect(statement_params=
+                _ = session.sql(query).collect(statement_params=score_statement_params)
             sp_df = session.sql(sql_queries[-1])
-            df: pd.DataFrame = sp_df.to_pandas(statement_params=
+            df: pd.DataFrame = sp_df.to_pandas(statement_params=score_statement_params)
             df.columns = sp_df.columns
 
             local_score_file_name = get_temp_file_path()
-            session.file.get(stage_score_file_name, local_score_file_name, statement_params=
+            session.file.get(stage_score_file_name, local_score_file_name, statement_params=score_statement_params)
 
             local_score_file_name_path = os.path.join(local_score_file_name, os.listdir(local_score_file_name)[0])
             with open(local_score_file_name_path, mode="r+b") as local_score_file_obj:
@@ -348,7 +348,7 @@ class SnowparkHandlers:
             return result
 
         # Call score sproc
-
+        score_statement_params = telemetry.get_function_usage_statement_params(
             project=_PROJECT,
             subproject=self._subproject,
             function_name=telemetry.get_statement_params_full_func_name(
@@ -357,6 +357,8 @@ class SnowparkHandlers:
             api_calls=[Session.call],
             custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
         )
+
+        kwargs = telemetry.get_sproc_statement_params_kwargs(score_wrapper_sproc, score_statement_params)
         score: float = score_wrapper_sproc(
             session,
             queries,
@@ -364,7 +366,8 @@ class SnowparkHandlers:
             input_cols,
             label_cols,
             sample_weight_col,
-
+            score_statement_params,
+            **kwargs,
         )
 
         cleanup_temp_files([local_score_file_name])
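The renamed `score_statement_params` dict now reaches every Snowpark call inside the scoring sproc, so each query carries the telemetry tags. A minimal sketch of the pattern, assuming an existing session:

```python
from typing import Dict

from snowflake.snowpark import Session


def run_tagged_queries(session: Session, statement_params: Dict[str, str]) -> None:
    # Executing methods accept statement_params, so the tags built by
    # telemetry.get_function_usage_statement_params ride along with each query.
    rows = session.sql("SELECT 1 AS X").collect(statement_params=statement_params)
    frame = session.sql("SELECT 1 AS X").to_pandas(statement_params=statement_params)
    print(rows, frame.shape)
```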
snowflake/ml/modeling/_internal/snowpark_trainer.py:

@@ -12,7 +12,11 @@ from snowflake.ml._internal.exceptions import (
     exceptions,
     modeling_error_messages,
 )
-from snowflake.ml._internal.utils import
+from snowflake.ml._internal.utils import (
+    identifier,
+    pkg_version_utils,
+    snowpark_dataframe_utils,
+)
 from snowflake.ml._internal.utils.query_result_checker import SqlResultValidator
 from snowflake.ml._internal.utils.temp_file_utils import (
     cleanup_temp_files,
@@ -253,11 +257,15 @@ class SnowparkModelTrainer:
 
         fit_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
 
+        relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
+            pkg_versions=model_spec.pkgDependencies, session=self.session
+        )
+
         fit_wrapper_sproc = self.session.sproc.register(
             func=self._build_fit_wrapper_sproc(model_spec=model_spec),
             is_permanent=False,
             name=fit_sproc_name,
-            packages=["snowflake-snowpark-python"] +
+            packages=["snowflake-snowpark-python"] + relaxed_dependencies,  # type: ignore[arg-type]
             replace=True,
             session=self.session,
             statement_params=statement_params,
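Both trainers now pass the model's pinned dependencies through `pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel` before registering the sproc. The sketch below illustrates the idea of relaxing pins against what a channel actually serves; it is not the library's implementation, and the fallback behavior shown is an assumption:

```python
from typing import Dict, List, Set

# Versions a hypothetical channel serves; the real resolver queries Snowflake.
AVAILABLE: Dict[str, Set[str]] = {
    "scikit-learn": {"1.2.2", "1.3.0"},
    "xgboost": {"1.7.6"},
}


def relax(pkg_versions: List[str]) -> List[str]:
    # Keep a pin only if the channel serves it; otherwise fall back to the
    # bare package name so the server can pick a compatible version.
    out = []
    for spec in pkg_versions:
        name, _, version = spec.partition("==")
        out.append(spec if version in AVAILABLE.get(name, set()) else name)
    return out


print(relax(["scikit-learn==1.3.2", "xgboost==1.7.6"]))
# ['scikit-learn', 'xgboost==1.7.6']
```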