snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
- snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
- snowflake/ml/_internal/env_utils.py +31 -52
- snowflake/ml/_internal/file_utils.py +17 -0
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/_internal/utils/query_result_checker.py +8 -5
- snowflake/ml/_internal/utils/snowflake_env.py +95 -0
- snowflake/ml/fileset/parquet_parser.py +31 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/model_impl.py +172 -13
- snowflake/ml/model/_client/model/model_version_impl.py +96 -52
- snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
- snowflake/ml/model/_client/ops/model_ops.py +155 -9
- snowflake/ml/model/_client/sql/model.py +55 -10
- snowflake/ml/model/_client/sql/model_version.py +72 -61
- snowflake/ml/model/_client/sql/stage.py +10 -4
- snowflake/ml/model/_client/sql/tag.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/_signatures/core.py +20 -17
- snowflake/ml/model/custom_model.py +30 -27
- snowflake/ml/model/model_signature.py +16 -17
- snowflake/ml/model/type_hints.py +3 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
- snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
- snowflake/ml/modeling/_internal/model_specifications.py +3 -10
- snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
- snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
- snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
- snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
- snowflake/ml/modeling/cluster/birch.py +51 -16
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
- snowflake/ml/modeling/cluster/dbscan.py +51 -16
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
- snowflake/ml/modeling/cluster/k_means.py +51 -16
- snowflake/ml/modeling/cluster/mean_shift.py +51 -16
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
- snowflake/ml/modeling/cluster/optics.py +51 -16
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
- snowflake/ml/modeling/compose/column_transformer.py +51 -16
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
- snowflake/ml/modeling/covariance/oas.py +51 -16
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/pca.py +51 -16
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
- snowflake/ml/modeling/impute/knn_imputer.py +51 -16
- snowflake/ml/modeling/impute/missing_indicator.py +51 -16
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/lars.py +51 -16
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/perceptron.py +51 -16
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ridge.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
- snowflake/ml/modeling/manifold/isomap.py +51 -16
- snowflake/ml/modeling/manifold/mds.py +51 -16
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
- snowflake/ml/modeling/manifold/tsne.py +51 -16
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
- snowflake/ml/modeling/svm/linear_svc.py +51 -16
- snowflake/ml/modeling/svm/linear_svr.py +51 -16
- snowflake/ml/modeling/svm/nu_svc.py +51 -16
- snowflake/ml/modeling/svm/nu_svr.py +51 -16
- snowflake/ml/modeling/svm/svc.py +51 -16
- snowflake/ml/modeling/svm/svr.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
- snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
- snowflake/ml/registry/__init__.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +163 -0
- snowflake/ml/registry/model_registry.py +12 -0
- snowflake/ml/registry/registry.py +100 -90
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
- snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
- snowflake/ml/model/_client/model/model_method_info.py +0 -19
- snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
- /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
- /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py (new file)
@@ -0,0 +1,444 @@
+import inspect
+import os
+import tempfile
+from typing import Any, Dict, List, Optional
+
+import cloudpickle as cp
+import pandas as pd
+import pyarrow.parquet as pq
+
+from snowflake.ml._internal import telemetry
+from snowflake.ml._internal.exceptions import (
+    error_codes,
+    exceptions,
+    modeling_error_messages,
+)
+from snowflake.ml._internal.utils import pkg_version_utils
+from snowflake.ml._internal.utils.query_result_checker import ResultValidator
+from snowflake.ml._internal.utils.snowpark_dataframe_utils import (
+    cast_snowpark_dataframe,
+)
+from snowflake.ml._internal.utils.temp_file_utils import get_temp_file_path
+from snowflake.ml.modeling._internal.model_specifications import (
+    ModelSpecifications,
+    ModelSpecificationsBuilder,
+)
+from snowflake.ml.modeling._internal.snowpark_trainer import SnowparkModelTrainer
+from snowflake.snowpark import (
+    DataFrame,
+    Session,
+    exceptions as snowpark_exceptions,
+    functions as F,
+)
+from snowflake.snowpark._internal.utils import (
+    TempObjectType,
+    random_name_for_temp_object,
+)
+
+_PROJECT = "ModelDevelopment"
+
+
+def get_data_iterator(
+    file_paths: List[str],
+    batch_size: int,
+    input_cols: List[str],
+    label_cols: List[str],
+    sample_weight_col: Optional[str] = None,
+) -> Any:
+    from typing import List, Optional
+
+    import xgboost
+
+    class ParquetDataIterator(xgboost.DataIter):
+        """
+        This iterator reads parquet data stored in a specified files and returns
+        deserialized data, enabling seamless integration with the xgboost framework for
+        machine learning tasks.
+        """
+
+        def __init__(
+            self,
+            file_paths: List[str],
+            batch_size: int,
+            input_cols: List[str],
+            label_cols: List[str],
+            sample_weight_col: Optional[str] = None,
+        ) -> None:
+            """
+            Initialize the DataIterator.
+
+            Args:
+                file_paths: List of file paths containing the data.
+                batch_size: Target number of rows in each batch.
+                input_cols: The name(s) of one or more columns in a DataFrame containing a feature to be used for
+                    training.
+                label_cols: The name(s) of one or more columns in a DataFrame representing the target variable(s)
+                    to learn.
+                sample_weight_col: The column name representing the weight of training examples.
+            """
+            self._file_paths = file_paths
+            self._batch_size = batch_size
+            self._input_cols = input_cols
+            self._label_cols = label_cols
+            self._sample_weight_col = sample_weight_col
+
+            # File index
+            self._it = 0
+            # Pandas dataframe containing temp data
+            self._df = None
+            # XGBoost will generate some cache files under current directory with the prefix
+            # "cache"
+            cache_dir_name = tempfile.mkdtemp()
+            super().__init__(cache_prefix=os.path.join(cache_dir_name, "cache"))
+
+        def next(self, batch_consumer_fn) -> int: # type: ignore[no-untyped-def]
+            """Advance the iterator by 1 step and pass the data to XGBoost's batch_consumer_fn.
+            This function is called by XGBoost during the construction of ``DMatrix``
+
+            Args:
+                batch_consumer_fn: batch consumer function
+
+            Returns:
+                0 if there is no more data, else 1.
+            """
+            while (self._df is None) or (self._df.shape[0] < self._batch_size):
+                # Read files and append data to temp df until batch size is reached.
+                if self._it == len(self._file_paths):
+                    break
+                new_df = pq.read_table(self._file_paths[self._it]).to_pandas()
+                self._it += 1
+
+                if self._df is None:
+                    self._df = new_df
+                else:
+                    self._df = pd.concat([self._df, new_df], ignore_index=True)
+
+            if (self._df is None) or (self._df.shape[0] == 0):
+                # No more data
+                return 0
+
+            # Slice the temp df and save the remainder in the temp df
+            batch_end_index = min(self._batch_size, self._df.shape[0])
+            batch_df = self._df.iloc[:batch_end_index]
+            self._df = self._df.truncate(before=batch_end_index).reset_index(drop=True)
+
+            # TODO(snandamuri): Make it proper to support categorical features, etc.
+            func_args = {
+                "data": batch_df[self._input_cols],
+                "label": batch_df[self._label_cols].squeeze(),
+            }
+            if self._sample_weight_col is not None:
+                func_args["weight"] = batch_df[self._sample_weight_col].squeeze()
+
+            batch_consumer_fn(**func_args)
+            # Return 1 to let XGBoost know we haven't seen all the files yet.
+            return 1
+
+        def reset(self) -> None:
+            """Reset the iterator to its beginning"""
+            self._it = 0
+
+    return ParquetDataIterator(
+        file_paths=file_paths,
+        batch_size=batch_size,
+        input_cols=input_cols,
+        label_cols=label_cols,
+        sample_weight_col=sample_weight_col,
+    )
+
+
+def train_xgboost_model(
+    estimator: object,
+    file_paths: List[str],
+    batch_size: int,
+    input_cols: List[str],
+    label_cols: List[str],
+    sample_weight_col: Optional[str] = None,
+) -> object:
+    """
+    Function to train XGBoost models using the external memory version of XGBoost.
+    """
+    import xgboost
+
+    def _objective_decorator(func): # type: ignore[no-untyped-def]
+        def inner(preds, dmatrix): # type: ignore[no-untyped-def]
+            """internal function"""
+            labels = dmatrix.get_label()
+            return func(labels, preds)
+
+        return inner
+
+    assert isinstance(estimator, xgboost.XGBModel)
+    params = estimator.get_xgb_params()
+    obj = None
+
+    if isinstance(estimator, xgboost.XGBClassifier):
+        # TODO (snandamuri): Find better way to get expected_classes
+        # Set: self.classes_, self.n_classes_
+        expected_classes = pd.unique(pq.read_table(file_paths[0]).to_pandas()[label_cols].squeeze())
+        estimator.n_classes_ = len(expected_classes)
+        if callable(estimator.objective):
+            obj = _objective_decorator(estimator.objective) # type: ignore[no-untyped-call]
+            # Use default value. Is it really not used ?
+            params["objective"] = "binary:logistic"
+
+        if len(expected_classes) > 2:
+            # Switch to using a multiclass objective in the underlying XGB instance
+            if params.get("objective", None) != "multi:softmax":
+                params["objective"] = "multi:softprob"
+            params["num_class"] = len(expected_classes)
+
+    if "tree_method" not in params.keys() or params["tree_method"] is None or params["tree_method"].lower() == "exact":
+        params["tree_method"] = "hist"
+
+    if (
+        "grow_policy" not in params.keys()
+        or params["grow_policy"] is None
+        or params["grow_policy"].lower() != "depthwise"
+    ):
+        params["grow_policy"] = "depthwise"
+
+    it = get_data_iterator(
+        file_paths=file_paths,
+        batch_size=batch_size,
+        input_cols=input_cols,
+        label_cols=label_cols,
+        sample_weight_col=sample_weight_col,
+    )
+    Xy = xgboost.DMatrix(it)
+    estimator._Booster = xgboost.train(
+        params,
+        Xy,
+        estimator.get_num_boosting_rounds(),
+        evals=[],
+        early_stopping_rounds=estimator.early_stopping_rounds,
+        evals_result=None,
+        obj=obj,
+        custom_metric=estimator.eval_metric,
+        verbose_eval=None,
+        xgb_model=None,
+        callbacks=None,
+    )
+    return estimator
+
+
+cp.register_pickle_by_value(inspect.getmodule(get_data_iterator))
+cp.register_pickle_by_value(inspect.getmodule(train_xgboost_model))
+
+
+class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
+    """
+    When working with large datasets, training XGBoost models traditionally requires loading the entire dataset into
+    memory, which can be costly and sometimes infeasible due to memory constraints. To solve this problem, XGBoost
+    provides support for loading data from external memory using a built-in data parser. With this feature enabled,
+    the training process occurs in a two-step approach:
+        Preprocessing Step: Input data is read and parsed into an internal format, such as CSR, CSC, or sorted CSC.
+            Processed state is appended to an in-memory buffer. Once the buffer reaches a predefined size, it is
+            written out to disk as a page.
+        Tree Construction Step: During the tree construction phase, the data pages stored on disk are streamed via
+            a multi-threaded pre-fetcher, allowing the model to efficiently access and process the data without
+            overloading memory.
+    """
+
+    def __init__(
+        self,
+        estimator: object,
+        dataset: DataFrame,
+        session: Session,
+        input_cols: List[str],
+        label_cols: Optional[List[str]],
+        sample_weight_col: Optional[str],
+        autogenerated: bool = False,
+        subproject: str = "",
+        batch_size: int = 10000,
+    ) -> None:
+        """
+        Initializes the XGBoostExternalMemoryTrainer with a model, a Snowpark DataFrame, feature, and label column
+        names, etc.
+
+        Args:
+            estimator: SKLearn compatible estimator or transformer object.
+            dataset: The dataset used for training the model.
+            session: Snowflake session object to be used for training.
+            input_cols: The name(s) of one or more columns in a DataFrame containing a feature to be used for training.
+            label_cols: The name(s) of one or more columns in a DataFrame representing the target variable(s) to learn.
+            sample_weight_col: The column name representing the weight of training examples.
+            autogenerated: A boolean denoting if the trainer is being used by autogenerated code or not.
+            subproject: subproject name to be used in telemetry.
+            batch_size: Number of the rows in the each batch processed during training.
+        """
+        super().__init__(
+            estimator=estimator,
+            dataset=dataset,
+            session=session,
+            input_cols=input_cols,
+            label_cols=label_cols,
+            sample_weight_col=sample_weight_col,
+            autogenerated=autogenerated,
+            subproject=subproject,
+        )
+        self._batch_size = batch_size
+
+    def _get_xgb_external_memory_fit_wrapper_sproc(
+        self,
+        model_spec: ModelSpecifications,
+        session: Session,
+        statement_params: Dict[str, str],
+        import_file_paths: List[str],
+    ) -> Any:
+        fit_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
+
+        relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
+            pkg_versions=model_spec.pkgDependencies, session=self.session
+        )
+
+        @F.sproc(
+            is_permanent=False,
+            name=fit_sproc_name,
+            packages=list(["snowflake-snowpark-python"] + relaxed_dependencies),
+            replace=True,
+            session=session,
+            statement_params=statement_params,
+            anonymous=True,
+            imports=list(import_file_paths),
+        ) # type: ignore[misc]
+        def fit_wrapper_sproc(
+            session: Session,
+            stage_transform_file_name: str,
+            stage_result_file_name: str,
+            dataset_stage_name: str,
+            batch_size: int,
+            input_cols: List[str],
+            label_cols: List[str],
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str],
+        ) -> str:
+            import os
+            import sys
+
+            import cloudpickle as cp
+
+            local_transform_file_name = get_temp_file_path()
+
+            session.file.get(stage_transform_file_name, local_transform_file_name, statement_params=statement_params)
+
+            local_transform_file_path = os.path.join(
+                local_transform_file_name, os.listdir(local_transform_file_name)[0]
+            )
+            with open(local_transform_file_path, mode="r+b") as local_transform_file_obj:
+                estimator = cp.load(local_transform_file_obj)
+
+            data_files = [
+                os.path.join(sys._xoptions["snowflake_import_directory"], filename)
+                for filename in os.listdir(sys._xoptions["snowflake_import_directory"])
+                if filename.startswith(dataset_stage_name)
+            ]
+
+            estimator = train_xgboost_model(
+                estimator=estimator,
+                file_paths=data_files,
+                batch_size=batch_size,
+                input_cols=input_cols,
+                label_cols=label_cols,
+                sample_weight_col=sample_weight_col,
+            )
+
+            local_result_file_name = get_temp_file_path()
+            with open(local_result_file_name, mode="w+b") as local_result_file_obj:
+                cp.dump(estimator, local_result_file_obj)
+
+            session.file.put(
+                local_result_file_name,
+                stage_result_file_name,
+                auto_compress=False,
+                overwrite=True,
+                statement_params=statement_params,
+            )
+
+            # Note: you can add something like + "|" + str(df) to the return string
+            # to pass debug information to the caller.
+            return str(os.path.basename(local_result_file_name))
+
+        return fit_wrapper_sproc
+
+    def _write_training_data_to_stage(self, dataset_stage_name: str) -> List[str]:
+        """
+        Materializes the training to the specified stage and returns the list of stage file paths.
+
+        Args:
+            dataset_stage_name: Target stage to materialize training data.
+
+        Returns:
+            List of stage file paths that contain the materialized data.
+        """
+        # Stage data.
+        dataset = cast_snowpark_dataframe(self.dataset)
+        remote_file_path = f"{dataset_stage_name}/{dataset_stage_name}.parquet"
+        copy_response = dataset.write.copy_into_location( # type:ignore[call-overload]
+            remote_file_path, file_format_type="parquet", header=True, overwrite=True
+        )
+        ResultValidator(result=copy_response).has_dimensions(expected_rows=1).validate()
+        data_file_paths = [f"@{row.name}" for row in self.session.sql(f"LIST @{dataset_stage_name}").collect()]
+        return data_file_paths
+
+    def train(self) -> object:
+        """
+        Runs hyper parameter optimization by distributing the tasks across warehouse.
+
+        Returns:
+            Trained model
+
+        Raises:
+            SnowflakeMLException: For known types of user and system errors.
+            e: For every unexpected exception from SnowflakeClient.
+        """
+        temp_stage_name = self._create_temp_stage()
+        (stage_transform_file_name, stage_result_file_name) = self._upload_model_to_stage(stage_name=temp_stage_name)
+        data_file_paths = self._write_training_data_to_stage(dataset_stage_name=temp_stage_name)
+
+        # Call fit sproc
+        statement_params = telemetry.get_function_usage_statement_params(
+            project=_PROJECT,
+            subproject=self._subproject,
+            function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
+            api_calls=[Session.call],
+            custom_tags=None,
+        )
+
+        model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
+        fit_wrapper = self._get_xgb_external_memory_fit_wrapper_sproc(
+            model_spec=model_spec,
+            session=self.session,
+            statement_params=statement_params,
+            import_file_paths=data_file_paths,
+        )
+
+        try:
+            sproc_export_file_name = fit_wrapper(
+                self.session,
+                stage_transform_file_name,
+                stage_result_file_name,
+                temp_stage_name,
+                self._batch_size,
+                self.input_cols,
+                self.label_cols,
+                self.sample_weight_col,
+                statement_params,
+            )
+        except snowpark_exceptions.SnowparkClientException as e:
+            if "fit() missing 1 required positional argument: 'y'" in str(e):
+                raise exceptions.SnowflakeMLException(
+                    error_code=error_codes.NOT_FOUND,
+                    original_exception=RuntimeError(modeling_error_messages.ATTRIBUTE_NOT_SET.format("label_cols")),
+                ) from e
+            raise e
+
+        if "|" in sproc_export_file_name:
+            fields = sproc_export_file_name.strip().split("|")
+            sproc_export_file_name = fields[0]
+
+        return self._fetch_model_from_stage(
+            dir_path=stage_result_file_name,
+            file_name=sproc_export_file_name,
+            statement_params=statement_params,
+        )
snowflake/ml/modeling/calibration/calibrated_classifier_cv.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.calibration".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class CalibratedClassifierCV(BaseTransformer):
     r"""Probability calibration with isotonic regression or logistic regression
     For more details on this class, see [sklearn.calibration.CalibratedClassifierCV]
@@ -192,7 +204,9 @@ class CalibratedClassifierCV(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
         deps = deps | gather_dependencies(estimator)
         deps = deps | gather_dependencies(base_estimator)
@@ -275,11 +289,6 @@ class CalibratedClassifierCV(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -307,7 +316,9 @@ class CalibratedClassifierCV(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -578,6 +589,22 @@ class CalibratedClassifierCV(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+            # If we were unable to assign a type to this transform in the factory, infer the type here.
+            if expected_dtype == "":
+                # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+                if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                    expected_dtype = "ARRAY"
+                # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+                elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                    expected_dtype = "ARRAY"
+                else:
+                    output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                    # We can only infer the output types from the input types if the following two statemetns are true:
+                    # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                    # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                    if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                        expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -593,8 +620,8 @@ class CalibratedClassifierCV(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -607,13 +634,21 @@ class CalibratedClassifierCV(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/cluster/affinity_propagation.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return True and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class AffinityPropagation(BaseTransformer):
     r"""Perform Affinity Propagation Clustering of data
     For more details on this class, see [sklearn.cluster.AffinityPropagation]
@@ -167,7 +179,9 @@ class AffinityPropagation(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -250,11 +264,6 @@ class AffinityPropagation(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -282,7 +291,9 @@ class AffinityPropagation(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
            autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -553,6 +564,22 @@ class AffinityPropagation(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+            # If we were unable to assign a type to this transform in the factory, infer the type here.
+            if expected_dtype == "":
+                # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+                if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                    expected_dtype = "ARRAY"
+                # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+                elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                    expected_dtype = "ARRAY"
+                else:
+                    output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                    # We can only infer the output types from the input types if the following two statemetns are true:
+                    # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                    # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                    if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                        expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -568,8 +595,8 @@ class AffinityPropagation(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Fit clustering from features/affinity matrix; return cluster labels
         For more details on this function, see [sklearn.cluster.AffinityPropagation.fit_predict]
         (https://scikit-learn.org/stable/modules/generated/sklearn.cluster.AffinityPropagation.html#sklearn.cluster.AffinityPropagation.fit_predict)
@@ -584,13 +611,21 @@ class AffinityPropagation(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
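
Both autogenerated estimator files above gate `fit_predict`/`fit_transform` behind module-level check factories (`_is_fit_predict_method_enabled`, `_is_fit_transform_method_enabled`) whose hard-coded `True`/`False` flag decides, per estimator, whether scikit-learn's `available_if` exposes the method on the wrapper at all (for example `True` for `AffinityPropagation.fit_predict`, `False` for both on `CalibratedClassifierCV`), replacing the previous `@available_if(...)` decorators shown truncated in the removed lines. A rough sketch of how that gating behaves is below; the wrapper class and its `_sklearn_object` attribute are hypothetical stand-ins, not the package's actual base class.

```python
# Sketch of available_if gating as used by the generated wrappers (illustration only).
from typing import Any, Callable

from sklearn.cluster import AffinityPropagation
from sklearn.utils.metaestimators import available_if


def _is_fit_predict_method_enabled(enabled: bool) -> Callable[[Any], bool]:
    # The generated code hard-codes the enabled flag per estimator instead of taking a parameter.
    def check(self: "WrapperSketch") -> bool:
        return enabled and callable(getattr(self._sklearn_object, "fit_predict", None))

    return check


class WrapperSketch:
    def __init__(self, sklearn_object: Any) -> None:
        self._sklearn_object = sklearn_object

    @available_if(_is_fit_predict_method_enabled(enabled=True))
    def fit_predict(self, X: Any) -> Any:
        # Reachable only when the check passes; otherwise attribute access raises
        # AttributeError, so hasattr(wrapper, "fit_predict") reports False.
        return self._sklearn_object.fit_predict(X)


# hasattr(WrapperSketch(AffinityPropagation()), "fit_predict") -> True
# With enabled=False the same attribute lookup raises AttributeError instead.
```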