snowflake-ml-python 1.1.1__py3-none-any.whl → 1.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +1 -1
- snowflake/cortex/_extract_answer.py +1 -1
- snowflake/cortex/_sentiment.py +1 -1
- snowflake/cortex/_summarize.py +1 -1
- snowflake/cortex/_translate.py +1 -1
- snowflake/ml/_internal/env_utils.py +68 -6
- snowflake/ml/_internal/file_utils.py +34 -4
- snowflake/ml/_internal/telemetry.py +79 -91
- snowflake/ml/_internal/utils/retryable_http.py +16 -4
- snowflake/ml/_internal/utils/spcs_attribution_utils.py +122 -0
- snowflake/ml/dataset/dataset.py +1 -1
- snowflake/ml/model/_api.py +21 -14
- snowflake/ml/model/_client/model/model_impl.py +176 -0
- snowflake/ml/model/_client/model/model_method_info.py +19 -0
- snowflake/ml/model/_client/model/model_version_impl.py +291 -0
- snowflake/ml/model/_client/ops/metadata_ops.py +107 -0
- snowflake/ml/model/_client/ops/model_ops.py +308 -0
- snowflake/ml/model/_client/sql/model.py +75 -0
- snowflake/ml/model/_client/sql/model_version.py +213 -0
- snowflake/ml/model/_client/sql/stage.py +40 -0
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -4
- snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +24 -8
- snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +23 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +14 -2
- snowflake/ml/model/_deploy_client/utils/constants.py +1 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +2 -2
- snowflake/ml/model/_model_composer/model_composer.py +31 -9
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +25 -10
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -2
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +34 -3
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +1 -1
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +3 -1
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +10 -28
- snowflake/ml/model/_packager/model_meta/model_meta.py +18 -16
- snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
- snowflake/ml/model/model_signature.py +108 -53
- snowflake/ml/model/type_hints.py +1 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +554 -0
- snowflake/ml/modeling/_internal/estimator_protocols.py +1 -60
- snowflake/ml/modeling/_internal/model_specifications.py +146 -0
- snowflake/ml/modeling/_internal/model_trainer.py +13 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +78 -0
- snowflake/ml/modeling/_internal/pandas_trainer.py +54 -0
- snowflake/ml/modeling/_internal/snowpark_handlers.py +6 -760
- snowflake/ml/modeling/_internal/snowpark_trainer.py +331 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +96 -124
- snowflake/ml/modeling/cluster/affinity_propagation.py +94 -124
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +94 -124
- snowflake/ml/modeling/cluster/birch.py +94 -124
- snowflake/ml/modeling/cluster/bisecting_k_means.py +94 -124
- snowflake/ml/modeling/cluster/dbscan.py +94 -124
- snowflake/ml/modeling/cluster/feature_agglomeration.py +94 -124
- snowflake/ml/modeling/cluster/k_means.py +93 -124
- snowflake/ml/modeling/cluster/mean_shift.py +94 -124
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +93 -124
- snowflake/ml/modeling/cluster/optics.py +94 -124
- snowflake/ml/modeling/cluster/spectral_biclustering.py +94 -124
- snowflake/ml/modeling/cluster/spectral_clustering.py +94 -124
- snowflake/ml/modeling/cluster/spectral_coclustering.py +94 -124
- snowflake/ml/modeling/compose/column_transformer.py +94 -124
- snowflake/ml/modeling/compose/transformed_target_regressor.py +96 -124
- snowflake/ml/modeling/covariance/elliptic_envelope.py +94 -124
- snowflake/ml/modeling/covariance/empirical_covariance.py +80 -110
- snowflake/ml/modeling/covariance/graphical_lasso.py +94 -124
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +94 -124
- snowflake/ml/modeling/covariance/ledoit_wolf.py +85 -115
- snowflake/ml/modeling/covariance/min_cov_det.py +94 -124
- snowflake/ml/modeling/covariance/oas.py +80 -110
- snowflake/ml/modeling/covariance/shrunk_covariance.py +84 -114
- snowflake/ml/modeling/decomposition/dictionary_learning.py +94 -124
- snowflake/ml/modeling/decomposition/factor_analysis.py +94 -124
- snowflake/ml/modeling/decomposition/fast_ica.py +94 -124
- snowflake/ml/modeling/decomposition/incremental_pca.py +94 -124
- snowflake/ml/modeling/decomposition/kernel_pca.py +94 -124
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +94 -124
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +94 -124
- snowflake/ml/modeling/decomposition/pca.py +94 -124
- snowflake/ml/modeling/decomposition/sparse_pca.py +94 -124
- snowflake/ml/modeling/decomposition/truncated_svd.py +94 -124
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +96 -124
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +91 -119
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/bagging_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/bagging_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/isolation_forest.py +94 -124
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/stacking_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/voting_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/voting_regressor.py +91 -119
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +82 -110
- snowflake/ml/modeling/feature_selection/select_fdr.py +80 -108
- snowflake/ml/modeling/feature_selection/select_fpr.py +80 -108
- snowflake/ml/modeling/feature_selection/select_fwe.py +80 -108
- snowflake/ml/modeling/feature_selection/select_k_best.py +81 -109
- snowflake/ml/modeling/feature_selection/select_percentile.py +80 -108
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +94 -124
- snowflake/ml/modeling/feature_selection/variance_threshold.py +76 -106
- snowflake/ml/modeling/framework/base.py +2 -2
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +96 -124
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +96 -124
- snowflake/ml/modeling/impute/iterative_imputer.py +94 -124
- snowflake/ml/modeling/impute/knn_imputer.py +94 -124
- snowflake/ml/modeling/impute/missing_indicator.py +94 -124
- snowflake/ml/modeling/impute/simple_imputer.py +1 -1
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +77 -107
- snowflake/ml/modeling/kernel_approximation/nystroem.py +94 -124
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +94 -124
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +86 -116
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +84 -114
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +96 -124
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +71 -100
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +71 -100
- snowflake/ml/modeling/linear_model/ard_regression.py +96 -124
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +96 -124
- snowflake/ml/modeling/linear_model/elastic_net.py +96 -124
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +96 -124
- snowflake/ml/modeling/linear_model/gamma_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/huber_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/lars.py +96 -124
- snowflake/ml/modeling/linear_model/lars_cv.py +96 -124
- snowflake/ml/modeling/linear_model/lasso.py +96 -124
- snowflake/ml/modeling/linear_model/lasso_cv.py +96 -124
- snowflake/ml/modeling/linear_model/lasso_lars.py +96 -124
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +96 -124
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +96 -124
- snowflake/ml/modeling/linear_model/linear_regression.py +91 -119
- snowflake/ml/modeling/linear_model/logistic_regression.py +96 -124
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +96 -124
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +96 -124
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +96 -124
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +96 -124
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +96 -124
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +96 -124
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +96 -124
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +95 -124
- snowflake/ml/modeling/linear_model/perceptron.py +95 -124
- snowflake/ml/modeling/linear_model/poisson_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/ransac_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/ridge.py +96 -124
- snowflake/ml/modeling/linear_model/ridge_classifier.py +96 -124
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +96 -124
- snowflake/ml/modeling/linear_model/ridge_cv.py +96 -124
- snowflake/ml/modeling/linear_model/sgd_classifier.py +96 -124
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +94 -124
- snowflake/ml/modeling/linear_model/sgd_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +96 -124
- snowflake/ml/modeling/manifold/isomap.py +94 -124
- snowflake/ml/modeling/manifold/mds.py +94 -124
- snowflake/ml/modeling/manifold/spectral_embedding.py +94 -124
- snowflake/ml/modeling/manifold/tsne.py +94 -124
- snowflake/ml/modeling/metrics/classification.py +187 -52
- snowflake/ml/modeling/metrics/correlation.py +4 -2
- snowflake/ml/modeling/metrics/covariance.py +7 -4
- snowflake/ml/modeling/metrics/ranking.py +32 -16
- snowflake/ml/modeling/metrics/regression.py +60 -32
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +94 -124
- snowflake/ml/modeling/mixture/gaussian_mixture.py +94 -124
- snowflake/ml/modeling/model_selection/grid_search_cv.py +88 -138
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +90 -144
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +86 -114
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +93 -121
- snowflake/ml/modeling/multiclass/output_code_classifier.py +94 -122
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +92 -120
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +96 -124
- snowflake/ml/modeling/naive_bayes/complement_nb.py +92 -120
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -107
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +88 -116
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +96 -124
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +96 -124
- snowflake/ml/modeling/neighbors/kernel_density.py +94 -124
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +94 -124
- snowflake/ml/modeling/neighbors/nearest_centroid.py +89 -117
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +94 -124
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +96 -124
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +96 -124
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +96 -124
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +94 -124
- snowflake/ml/modeling/neural_network/mlp_classifier.py +96 -124
- snowflake/ml/modeling/neural_network/mlp_regressor.py +96 -124
- snowflake/ml/modeling/parameters/disable_distributed_hpo.py +2 -6
- snowflake/ml/modeling/preprocessing/binarizer.py +14 -9
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +0 -4
- snowflake/ml/modeling/preprocessing/label_encoder.py +21 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +20 -14
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +35 -19
- snowflake/ml/modeling/preprocessing/normalizer.py +6 -9
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +20 -13
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +25 -13
- snowflake/ml/modeling/preprocessing/polynomial_features.py +94 -124
- snowflake/ml/modeling/preprocessing/robust_scaler.py +28 -14
- snowflake/ml/modeling/preprocessing/standard_scaler.py +25 -13
- snowflake/ml/modeling/semi_supervised/label_propagation.py +96 -124
- snowflake/ml/modeling/semi_supervised/label_spreading.py +96 -124
- snowflake/ml/modeling/svm/linear_svc.py +96 -124
- snowflake/ml/modeling/svm/linear_svr.py +96 -124
- snowflake/ml/modeling/svm/nu_svc.py +96 -124
- snowflake/ml/modeling/svm/nu_svr.py +96 -124
- snowflake/ml/modeling/svm/svc.py +96 -124
- snowflake/ml/modeling/svm/svr.py +96 -124
- snowflake/ml/modeling/tree/decision_tree_classifier.py +96 -124
- snowflake/ml/modeling/tree/decision_tree_regressor.py +96 -124
- snowflake/ml/modeling/tree/extra_tree_classifier.py +96 -124
- snowflake/ml/modeling/tree/extra_tree_regressor.py +96 -124
- snowflake/ml/modeling/xgboost/xgb_classifier.py +96 -125
- snowflake/ml/modeling/xgboost/xgb_regressor.py +96 -125
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +96 -125
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +96 -125
- snowflake/ml/registry/model_registry.py +2 -0
- snowflake/ml/registry/registry.py +215 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.1.1.dist-info → snowflake_ml_python-1.1.2.dist-info}/METADATA +21 -3
- snowflake_ml_python-1.1.2.dist-info/RECORD +347 -0
- snowflake_ml_python-1.1.1.dist-info/RECORD +0 -331
- {snowflake_ml_python-1.1.1.dist-info → snowflake_ml_python-1.1.2.dist-info}/WHEEL +0 -0
@@ -22,17 +22,19 @@ from sklearn.utils.metaestimators import available_if
|
|
22
22
|
from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
23
23
|
from snowflake.ml._internal import telemetry
|
24
24
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
25
|
+
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
25
26
|
from snowflake.ml._internal.utils import pkg_version_utils, identifier
|
26
|
-
from snowflake.snowpark import DataFrame
|
27
|
+
from snowflake.snowpark import DataFrame, Session
|
27
28
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
28
29
|
from snowflake.ml.modeling._internal.snowpark_handlers import SnowparkHandlers as HandlersImpl
|
30
|
+
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
|
+
from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
|
29
32
|
from snowflake.ml.modeling._internal.estimator_utils import (
|
30
33
|
gather_dependencies,
|
31
34
|
original_estimator_has_callable,
|
32
35
|
transform_snowml_obj_to_sklearn_obj,
|
33
36
|
validate_sklearn_args,
|
34
37
|
)
|
35
|
-
from snowflake.ml.modeling._internal.snowpark_handlers import SklearnWrapperProvider
|
36
38
|
from snowflake.ml.modeling._internal.estimator_protocols import FitPredictHandlers
|
37
39
|
|
38
40
|
from snowflake.ml.model.model_signature import (
|
@@ -52,7 +54,6 @@ _PROJECT = "ModelDevelopment"
|
|
52
54
|
_SUBPROJECT = "".join([s.capitalize() for s in "sklearn.kernel_approximation".replace("sklearn.", "").split("_")])
|
53
55
|
|
54
56
|
|
55
|
-
|
56
57
|
class AdditiveChi2Sampler(BaseTransformer):
|
57
58
|
r"""Approximate feature map for additive chi2 kernel
|
58
59
|
For more details on this class, see [sklearn.kernel_approximation.AdditiveChi2Sampler]
|
@@ -60,47 +61,54 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
60
61
|
|
61
62
|
Parameters
|
62
63
|
----------
|
63
|
-
sample_steps: int, default=2
|
64
|
-
Gives the number of (complex) sampling points.
|
65
|
-
|
66
|
-
sample_interval: float, default=None
|
67
|
-
Sampling interval. Must be specified when sample_steps not in {1,2,3}.
|
68
64
|
|
69
65
|
input_cols: Optional[Union[str, List[str]]]
|
70
66
|
A string or list of strings representing column names that contain features.
|
71
67
|
If this parameter is not specified, all columns in the input DataFrame except
|
72
68
|
the columns specified by label_cols, sample_weight_col, and passthrough_cols
|
73
|
-
parameters are considered input columns.
|
74
|
-
|
69
|
+
parameters are considered input columns. Input columns can also be set after
|
70
|
+
initialization with the `set_input_cols` method.
|
71
|
+
|
75
72
|
label_cols: Optional[Union[str, List[str]]]
|
76
|
-
|
77
|
-
|
78
|
-
columns. If this parameter is not specified, then object is fitted without
|
79
|
-
labels (like a transformer).
|
80
|
-
|
73
|
+
This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
|
74
|
+
|
81
75
|
output_cols: Optional[Union[str, List[str]]]
|
82
76
|
A string or list of strings representing column names that will store the
|
83
77
|
output of predict and transform operations. The length of output_cols must
|
84
|
-
match the expected number of output columns from the specific
|
78
|
+
match the expected number of output columns from the specific predictor or
|
85
79
|
transformer class used.
|
86
|
-
If this parameter
|
87
|
-
|
88
|
-
|
89
|
-
be set explicitly for transformers.
|
80
|
+
If you omit this parameter, output column names are derived by adding an
|
81
|
+
OUTPUT_ prefix to the label column names for supervised estimators, or
|
82
|
+
OUTPUT_<IDX>for unsupervised estimators. These inferred output column names
|
83
|
+
work for predictors, but output_cols must be set explicitly for transformers.
|
84
|
+
In general, explicitly specifying output column names is clearer, especially
|
85
|
+
if you don’t specify the input column names.
|
86
|
+
To transform in place, pass the same names for input_cols and output_cols.
|
87
|
+
be set explicitly for transformers. Output columns can also be set after
|
88
|
+
initialization with the `set_output_cols` method.
|
90
89
|
|
91
90
|
sample_weight_col: Optional[str]
|
92
91
|
A string representing the column name containing the sample weights.
|
93
|
-
This argument is only required when working with weighted datasets.
|
92
|
+
This argument is only required when working with weighted datasets. Sample
|
93
|
+
weight column can also be set after initialization with the
|
94
|
+
`set_sample_weight_col` method.
|
94
95
|
|
95
96
|
passthrough_cols: Optional[Union[str, List[str]]]
|
96
97
|
A string or a list of strings indicating column names to be excluded from any
|
97
98
|
operations (such as train, transform, or inference). These specified column(s)
|
98
99
|
will remain untouched throughout the process. This option is helpful in scenarios
|
99
100
|
requiring automatic input_cols inference, but need to avoid using specific
|
100
|
-
columns, like index columns, during training or inference.
|
101
|
+
columns, like index columns, during training or inference. Passthrough columns
|
102
|
+
can also be set after initialization with the `set_passthrough_cols` method.
|
101
103
|
|
102
104
|
drop_input_cols: Optional[bool], default=False
|
103
105
|
If set, the response of predict(), transform() methods will not contain input columns.
|
106
|
+
|
107
|
+
sample_steps: int, default=2
|
108
|
+
Gives the number of (complex) sampling points.
|
109
|
+
|
110
|
+
sample_interval: float, default=None
|
111
|
+
Sampling interval. Must be specified when sample_steps not in {1,2,3}.
|
104
112
|
"""
|
105
113
|
|
106
114
|
def __init__( # type: ignore[no-untyped-def]
|
@@ -123,7 +131,7 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
123
131
|
self.set_passthrough_cols(passthrough_cols)
|
124
132
|
self.set_drop_input_cols(drop_input_cols)
|
125
133
|
self.set_sample_weight_col(sample_weight_col)
|
126
|
-
deps = set(
|
134
|
+
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
127
135
|
|
128
136
|
self._deps = list(deps)
|
129
137
|
|
@@ -133,13 +141,14 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
133
141
|
args=init_args,
|
134
142
|
klass=sklearn.kernel_approximation.AdditiveChi2Sampler
|
135
143
|
)
|
136
|
-
self._sklearn_object = sklearn.kernel_approximation.AdditiveChi2Sampler(
|
144
|
+
self._sklearn_object: Any = sklearn.kernel_approximation.AdditiveChi2Sampler(
|
137
145
|
**cleaned_up_init_args,
|
138
146
|
)
|
139
147
|
self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
|
140
148
|
# If user used snowpark dataframe during fit, here it stores the snowpark input_cols, otherwise the processed input_cols
|
141
149
|
self._snowpark_cols: Optional[List[str]] = self.input_cols
|
142
|
-
self._handlers: FitPredictHandlers = HandlersImpl(class_name=AdditiveChi2Sampler.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True
|
150
|
+
self._handlers: FitPredictHandlers = HandlersImpl(class_name=AdditiveChi2Sampler.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
|
151
|
+
self._autogenerated = True
|
143
152
|
|
144
153
|
def _get_rand_id(self) -> str:
|
145
154
|
"""
|
@@ -195,54 +204,48 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
195
204
|
self
|
196
205
|
"""
|
197
206
|
self._infer_input_output_cols(dataset)
|
198
|
-
if isinstance(dataset,
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
self.
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
207
|
+
if isinstance(dataset, DataFrame):
|
208
|
+
session = dataset._session
|
209
|
+
assert session is not None # keep mypy happy
|
210
|
+
# Validate that key package version in user workspace are supported in snowflake conda channel
|
211
|
+
# If customer doesn't have package in conda channel, replace the ones have the closest versions
|
212
|
+
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
213
|
+
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
214
|
+
|
215
|
+
# Specify input columns so column pruning will be enforced
|
216
|
+
selected_cols = self._get_active_columns()
|
217
|
+
if len(selected_cols) > 0:
|
218
|
+
dataset = dataset.select(selected_cols)
|
219
|
+
|
220
|
+
self._snowpark_cols = dataset.select(self.input_cols).columns
|
221
|
+
|
222
|
+
# If we are already in a stored procedure, no need to kick off another one.
|
223
|
+
if SNOWML_SPROC_ENV in os.environ:
|
224
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
225
|
+
project=_PROJECT,
|
226
|
+
subproject=_SUBPROJECT,
|
227
|
+
function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), AdditiveChi2Sampler.__class__.__name__),
|
228
|
+
api_calls=[Session.call],
|
229
|
+
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
230
|
+
)
|
231
|
+
pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
|
232
|
+
pd_df.columns = dataset.columns
|
233
|
+
dataset = pd_df
|
234
|
+
|
235
|
+
model_trainer = ModelTrainerBuilder.build(
|
236
|
+
estimator=self._sklearn_object,
|
237
|
+
dataset=dataset,
|
238
|
+
input_cols=self.input_cols,
|
239
|
+
label_cols=self.label_cols,
|
240
|
+
sample_weight_col=self.sample_weight_col,
|
241
|
+
autogenerated=self._autogenerated,
|
242
|
+
subproject=_SUBPROJECT
|
243
|
+
)
|
244
|
+
self._sklearn_object = model_trainer.train()
|
214
245
|
self._is_fitted = True
|
215
246
|
self._get_model_signatures(dataset)
|
216
247
|
return self
|
217
248
|
|
218
|
-
def _fit_snowpark(self, dataset: DataFrame) -> None:
|
219
|
-
session = dataset._session
|
220
|
-
assert session is not None # keep mypy happy
|
221
|
-
# Validate that key package version in user workspace are supported in snowflake conda channel
|
222
|
-
# If customer doesn't have package in conda channel, replace the ones have the closest versions
|
223
|
-
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
224
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
225
|
-
|
226
|
-
# Specify input columns so column pruning will be enforced
|
227
|
-
selected_cols = self._get_active_columns()
|
228
|
-
if len(selected_cols) > 0:
|
229
|
-
dataset = dataset.select(selected_cols)
|
230
|
-
|
231
|
-
estimator = self._sklearn_object
|
232
|
-
assert estimator is not None # Keep mypy happy
|
233
|
-
|
234
|
-
self._snowpark_cols = dataset.select(self.input_cols).columns
|
235
|
-
|
236
|
-
self._sklearn_object = self._handlers.fit_snowpark(
|
237
|
-
dataset,
|
238
|
-
session,
|
239
|
-
estimator,
|
240
|
-
["snowflake-snowpark-python"] + self._get_dependencies(),
|
241
|
-
self.input_cols,
|
242
|
-
self.label_cols,
|
243
|
-
self.sample_weight_col,
|
244
|
-
)
|
245
|
-
|
246
249
|
def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
|
247
250
|
if self._drop_input_cols:
|
248
251
|
return []
|
@@ -430,11 +433,6 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
430
433
|
subproject=_SUBPROJECT,
|
431
434
|
custom_tags=dict([("autogen", True)]),
|
432
435
|
)
|
433
|
-
@telemetry.add_stmt_params_to_df(
|
434
|
-
project=_PROJECT,
|
435
|
-
subproject=_SUBPROJECT,
|
436
|
-
custom_tags=dict([("autogen", True)]),
|
437
|
-
)
|
438
436
|
def predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, pd.DataFrame]:
|
439
437
|
"""Method not supported for this class.
|
440
438
|
|
@@ -486,11 +484,6 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
486
484
|
subproject=_SUBPROJECT,
|
487
485
|
custom_tags=dict([("autogen", True)]),
|
488
486
|
)
|
489
|
-
@telemetry.add_stmt_params_to_df(
|
490
|
-
project=_PROJECT,
|
491
|
-
subproject=_SUBPROJECT,
|
492
|
-
custom_tags=dict([("autogen", True)]),
|
493
|
-
)
|
494
487
|
def transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, pd.DataFrame]:
|
495
488
|
"""Apply approximate feature map to X
|
496
489
|
For more details on this function, see [sklearn.kernel_approximation.AdditiveChi2Sampler.transform]
|
@@ -549,7 +542,8 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
549
542
|
if False:
|
550
543
|
self.fit(dataset)
|
551
544
|
assert self._sklearn_object is not None
|
552
|
-
|
545
|
+
labels : npt.NDArray[Any] = self._sklearn_object.labels_
|
546
|
+
return labels
|
553
547
|
else:
|
554
548
|
raise NotImplementedError
|
555
549
|
|
@@ -585,6 +579,7 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
585
579
|
output_cols = []
|
586
580
|
|
587
581
|
# Make sure column names are valid snowflake identifiers.
|
582
|
+
assert output_cols is not None # Make MyPy happy
|
588
583
|
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
589
584
|
|
590
585
|
return rv
|
@@ -595,11 +590,6 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
595
590
|
subproject=_SUBPROJECT,
|
596
591
|
custom_tags=dict([("autogen", True)]),
|
597
592
|
)
|
598
|
-
@telemetry.add_stmt_params_to_df(
|
599
|
-
project=_PROJECT,
|
600
|
-
subproject=_SUBPROJECT,
|
601
|
-
custom_tags=dict([("autogen", True)]),
|
602
|
-
)
|
603
593
|
def predict_proba(
|
604
594
|
self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "predict_proba_"
|
605
595
|
) -> Union[DataFrame, pd.DataFrame]:
|
@@ -640,11 +630,6 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
640
630
|
subproject=_SUBPROJECT,
|
641
631
|
custom_tags=dict([("autogen", True)]),
|
642
632
|
)
|
643
|
-
@telemetry.add_stmt_params_to_df(
|
644
|
-
project=_PROJECT,
|
645
|
-
subproject=_SUBPROJECT,
|
646
|
-
custom_tags=dict([("autogen", True)]),
|
647
|
-
)
|
648
633
|
def predict_log_proba(
|
649
634
|
self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "predict_log_proba_"
|
650
635
|
) -> Union[DataFrame, pd.DataFrame]:
|
@@ -681,16 +666,6 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
681
666
|
return output_df
|
682
667
|
|
683
668
|
@available_if(original_estimator_has_callable("decision_function")) # type: ignore[misc]
|
684
|
-
@telemetry.send_api_usage_telemetry(
|
685
|
-
project=_PROJECT,
|
686
|
-
subproject=_SUBPROJECT,
|
687
|
-
custom_tags=dict([("autogen", True)]),
|
688
|
-
)
|
689
|
-
@telemetry.add_stmt_params_to_df(
|
690
|
-
project=_PROJECT,
|
691
|
-
subproject=_SUBPROJECT,
|
692
|
-
custom_tags=dict([("autogen", True)]),
|
693
|
-
)
|
694
669
|
def decision_function(
|
695
670
|
self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "decision_function_"
|
696
671
|
) -> Union[DataFrame, pd.DataFrame]:
|
@@ -789,11 +764,6 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
789
764
|
subproject=_SUBPROJECT,
|
790
765
|
custom_tags=dict([("autogen", True)]),
|
791
766
|
)
|
792
|
-
@telemetry.add_stmt_params_to_df(
|
793
|
-
project=_PROJECT,
|
794
|
-
subproject=_SUBPROJECT,
|
795
|
-
custom_tags=dict([("autogen", True)]),
|
796
|
-
)
|
797
767
|
def kneighbors(
|
798
768
|
self,
|
799
769
|
dataset: Union[DataFrame, pd.DataFrame],
|
@@ -853,9 +823,9 @@ class AdditiveChi2Sampler(BaseTransformer):
|
|
853
823
|
# For classifier, the type of predict is the same as the type of label
|
854
824
|
if self._sklearn_object._estimator_type == 'classifier':
|
855
825
|
# label columns is the desired type for output
|
856
|
-
outputs = _infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True)
|
826
|
+
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
857
827
|
# rename the output columns
|
858
|
-
outputs = model_signature_utils.rename_features(outputs, self.output_cols)
|
828
|
+
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
859
829
|
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
860
830
|
([] if self._drop_input_cols else inputs)
|
861
831
|
+ outputs)
|
@@ -22,17 +22,19 @@ from sklearn.utils.metaestimators import available_if
|
|
22
22
|
from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
23
23
|
from snowflake.ml._internal import telemetry
|
24
24
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
25
|
+
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
25
26
|
from snowflake.ml._internal.utils import pkg_version_utils, identifier
|
26
|
-
from snowflake.snowpark import DataFrame
|
27
|
+
from snowflake.snowpark import DataFrame, Session
|
27
28
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
28
29
|
from snowflake.ml.modeling._internal.snowpark_handlers import SnowparkHandlers as HandlersImpl
|
30
|
+
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
|
+
from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
|
29
32
|
from snowflake.ml.modeling._internal.estimator_utils import (
|
30
33
|
gather_dependencies,
|
31
34
|
original_estimator_has_callable,
|
32
35
|
transform_snowml_obj_to_sklearn_obj,
|
33
36
|
validate_sklearn_args,
|
34
37
|
)
|
35
|
-
from snowflake.ml.modeling._internal.snowpark_handlers import SklearnWrapperProvider
|
36
38
|
from snowflake.ml.modeling._internal.estimator_protocols import FitPredictHandlers
|
37
39
|
|
38
40
|
from snowflake.ml.model.model_signature import (
|
@@ -52,7 +54,6 @@ _PROJECT = "ModelDevelopment"
|
|
52
54
|
_SUBPROJECT = "".join([s.capitalize() for s in "sklearn.kernel_approximation".replace("sklearn.", "").split("_")])
|
53
55
|
|
54
56
|
|
55
|
-
|
56
57
|
class Nystroem(BaseTransformer):
|
57
58
|
r"""Approximate a kernel map using a subset of the training data
|
58
59
|
For more details on this class, see [sklearn.kernel_approximation.Nystroem]
|
@@ -60,6 +61,49 @@ class Nystroem(BaseTransformer):
|
|
60
61
|
|
61
62
|
Parameters
|
62
63
|
----------
|
64
|
+
|
65
|
+
input_cols: Optional[Union[str, List[str]]]
|
66
|
+
A string or list of strings representing column names that contain features.
|
67
|
+
If this parameter is not specified, all columns in the input DataFrame except
|
68
|
+
the columns specified by label_cols, sample_weight_col, and passthrough_cols
|
69
|
+
parameters are considered input columns. Input columns can also be set after
|
70
|
+
initialization with the `set_input_cols` method.
|
71
|
+
|
72
|
+
label_cols: Optional[Union[str, List[str]]]
|
73
|
+
This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
|
74
|
+
|
75
|
+
output_cols: Optional[Union[str, List[str]]]
|
76
|
+
A string or list of strings representing column names that will store the
|
77
|
+
output of predict and transform operations. The length of output_cols must
|
78
|
+
match the expected number of output columns from the specific predictor or
|
79
|
+
transformer class used.
|
80
|
+
If you omit this parameter, output column names are derived by adding an
|
81
|
+
OUTPUT_ prefix to the label column names for supervised estimators, or
|
82
|
+
OUTPUT_<IDX>for unsupervised estimators. These inferred output column names
|
83
|
+
work for predictors, but output_cols must be set explicitly for transformers.
|
84
|
+
In general, explicitly specifying output column names is clearer, especially
|
85
|
+
if you don’t specify the input column names.
|
86
|
+
To transform in place, pass the same names for input_cols and output_cols.
|
87
|
+
be set explicitly for transformers. Output columns can also be set after
|
88
|
+
initialization with the `set_output_cols` method.
|
89
|
+
|
90
|
+
sample_weight_col: Optional[str]
|
91
|
+
A string representing the column name containing the sample weights.
|
92
|
+
This argument is only required when working with weighted datasets. Sample
|
93
|
+
weight column can also be set after initialization with the
|
94
|
+
`set_sample_weight_col` method.
|
95
|
+
|
96
|
+
passthrough_cols: Optional[Union[str, List[str]]]
|
97
|
+
A string or a list of strings indicating column names to be excluded from any
|
98
|
+
operations (such as train, transform, or inference). These specified column(s)
|
99
|
+
will remain untouched throughout the process. This option is helpful in scenarios
|
100
|
+
requiring automatic input_cols inference, but need to avoid using specific
|
101
|
+
columns, like index columns, during training or inference. Passthrough columns
|
102
|
+
can also be set after initialization with the `set_passthrough_cols` method.
|
103
|
+
|
104
|
+
drop_input_cols: Optional[bool], default=False
|
105
|
+
If set, the response of predict(), transform() methods will not contain input columns.
|
106
|
+
|
63
107
|
kernel: str or callable, default='rbf'
|
64
108
|
Kernel map to be approximated. A callable should accept two arguments
|
65
109
|
and the keyword arguments passed to this object as `kernel_params`, and
|
@@ -101,42 +145,6 @@ class Nystroem(BaseTransformer):
|
|
101
145
|
``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
|
102
146
|
``-1`` means using all processors. See :term:`Glossary <n_jobs>`
|
103
147
|
for more details.
|
104
|
-
|
105
|
-
input_cols: Optional[Union[str, List[str]]]
|
106
|
-
A string or list of strings representing column names that contain features.
|
107
|
-
If this parameter is not specified, all columns in the input DataFrame except
|
108
|
-
the columns specified by label_cols, sample_weight_col, and passthrough_cols
|
109
|
-
parameters are considered input columns.
|
110
|
-
|
111
|
-
label_cols: Optional[Union[str, List[str]]]
|
112
|
-
A string or list of strings representing column names that contain labels.
|
113
|
-
This is a required param for estimators, as there is no way to infer these
|
114
|
-
columns. If this parameter is not specified, then object is fitted without
|
115
|
-
labels (like a transformer).
|
116
|
-
|
117
|
-
output_cols: Optional[Union[str, List[str]]]
|
118
|
-
A string or list of strings representing column names that will store the
|
119
|
-
output of predict and transform operations. The length of output_cols must
|
120
|
-
match the expected number of output columns from the specific estimator or
|
121
|
-
transformer class used.
|
122
|
-
If this parameter is not specified, output column names are derived by
|
123
|
-
adding an OUTPUT_ prefix to the label column names. These inferred output
|
124
|
-
column names work for estimator's predict() method, but output_cols must
|
125
|
-
be set explicitly for transformers.
|
126
|
-
|
127
|
-
sample_weight_col: Optional[str]
|
128
|
-
A string representing the column name containing the sample weights.
|
129
|
-
This argument is only required when working with weighted datasets.
|
130
|
-
|
131
|
-
passthrough_cols: Optional[Union[str, List[str]]]
|
132
|
-
A string or a list of strings indicating column names to be excluded from any
|
133
|
-
operations (such as train, transform, or inference). These specified column(s)
|
134
|
-
will remain untouched throughout the process. This option is helpful in scenarios
|
135
|
-
requiring automatic input_cols inference, but need to avoid using specific
|
136
|
-
columns, like index columns, during training or inference.
|
137
|
-
|
138
|
-
drop_input_cols: Optional[bool], default=False
|
139
|
-
If set, the response of predict(), transform() methods will not contain input columns.
|
140
148
|
"""
|
141
149
|
|
142
150
|
def __init__( # type: ignore[no-untyped-def]
|
@@ -165,7 +173,7 @@ class Nystroem(BaseTransformer):
|
|
165
173
|
self.set_passthrough_cols(passthrough_cols)
|
166
174
|
self.set_drop_input_cols(drop_input_cols)
|
167
175
|
self.set_sample_weight_col(sample_weight_col)
|
168
|
-
deps = set(
|
176
|
+
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
169
177
|
|
170
178
|
self._deps = list(deps)
|
171
179
|
|
@@ -181,13 +189,14 @@ class Nystroem(BaseTransformer):
|
|
181
189
|
args=init_args,
|
182
190
|
klass=sklearn.kernel_approximation.Nystroem
|
183
191
|
)
|
184
|
-
self._sklearn_object = sklearn.kernel_approximation.Nystroem(
|
192
|
+
self._sklearn_object: Any = sklearn.kernel_approximation.Nystroem(
|
185
193
|
**cleaned_up_init_args,
|
186
194
|
)
|
187
195
|
self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
|
188
196
|
# If user used snowpark dataframe during fit, here it stores the snowpark input_cols, otherwise the processed input_cols
|
189
197
|
self._snowpark_cols: Optional[List[str]] = self.input_cols
|
190
|
-
self._handlers: FitPredictHandlers = HandlersImpl(class_name=Nystroem.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True
|
198
|
+
self._handlers: FitPredictHandlers = HandlersImpl(class_name=Nystroem.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
|
199
|
+
self._autogenerated = True
|
191
200
|
|
192
201
|
def _get_rand_id(self) -> str:
|
193
202
|
"""
|
@@ -243,54 +252,48 @@ class Nystroem(BaseTransformer):
|
|
243
252
|
self
|
244
253
|
"""
|
245
254
|
self._infer_input_output_cols(dataset)
|
246
|
-
if isinstance(dataset,
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
self.
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
255
|
+
if isinstance(dataset, DataFrame):
|
256
|
+
session = dataset._session
|
257
|
+
assert session is not None # keep mypy happy
|
258
|
+
# Validate that key package version in user workspace are supported in snowflake conda channel
|
259
|
+
# If customer doesn't have package in conda channel, replace the ones have the closest versions
|
260
|
+
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
261
|
+
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
262
|
+
|
263
|
+
# Specify input columns so column pruning will be enforced
|
264
|
+
selected_cols = self._get_active_columns()
|
265
|
+
if len(selected_cols) > 0:
|
266
|
+
dataset = dataset.select(selected_cols)
|
267
|
+
|
268
|
+
self._snowpark_cols = dataset.select(self.input_cols).columns
|
269
|
+
|
270
|
+
# If we are already in a stored procedure, no need to kick off another one.
|
271
|
+
if SNOWML_SPROC_ENV in os.environ:
|
272
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
273
|
+
project=_PROJECT,
|
274
|
+
subproject=_SUBPROJECT,
|
275
|
+
function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), Nystroem.__class__.__name__),
|
276
|
+
api_calls=[Session.call],
|
277
|
+
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
278
|
+
)
|
279
|
+
pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
|
280
|
+
pd_df.columns = dataset.columns
|
281
|
+
dataset = pd_df
|
282
|
+
|
283
|
+
model_trainer = ModelTrainerBuilder.build(
|
284
|
+
estimator=self._sklearn_object,
|
285
|
+
dataset=dataset,
|
286
|
+
input_cols=self.input_cols,
|
287
|
+
label_cols=self.label_cols,
|
288
|
+
sample_weight_col=self.sample_weight_col,
|
289
|
+
autogenerated=self._autogenerated,
|
290
|
+
subproject=_SUBPROJECT
|
291
|
+
)
|
292
|
+
self._sklearn_object = model_trainer.train()
|
262
293
|
self._is_fitted = True
|
263
294
|
self._get_model_signatures(dataset)
|
264
295
|
return self
|
265
296
|
|
266
|
-
def _fit_snowpark(self, dataset: DataFrame) -> None:
|
267
|
-
session = dataset._session
|
268
|
-
assert session is not None # keep mypy happy
|
269
|
-
# Validate that key package version in user workspace are supported in snowflake conda channel
|
270
|
-
# If customer doesn't have package in conda channel, replace the ones have the closest versions
|
271
|
-
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
272
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
273
|
-
|
274
|
-
# Specify input columns so column pruning will be enforced
|
275
|
-
selected_cols = self._get_active_columns()
|
276
|
-
if len(selected_cols) > 0:
|
277
|
-
dataset = dataset.select(selected_cols)
|
278
|
-
|
279
|
-
estimator = self._sklearn_object
|
280
|
-
assert estimator is not None # Keep mypy happy
|
281
|
-
|
282
|
-
self._snowpark_cols = dataset.select(self.input_cols).columns
|
283
|
-
|
284
|
-
self._sklearn_object = self._handlers.fit_snowpark(
|
285
|
-
dataset,
|
286
|
-
session,
|
287
|
-
estimator,
|
288
|
-
["snowflake-snowpark-python"] + self._get_dependencies(),
|
289
|
-
self.input_cols,
|
290
|
-
self.label_cols,
|
291
|
-
self.sample_weight_col,
|
292
|
-
)
|
293
|
-
|
294
297
|
def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
|
295
298
|
if self._drop_input_cols:
|
296
299
|
return []
|
@@ -478,11 +481,6 @@ class Nystroem(BaseTransformer):
|
|
478
481
|
subproject=_SUBPROJECT,
|
479
482
|
custom_tags=dict([("autogen", True)]),
|
480
483
|
)
|
481
|
-
@telemetry.add_stmt_params_to_df(
|
482
|
-
project=_PROJECT,
|
483
|
-
subproject=_SUBPROJECT,
|
484
|
-
custom_tags=dict([("autogen", True)]),
|
485
|
-
)
|
486
484
|
def predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, pd.DataFrame]:
|
487
485
|
"""Method not supported for this class.
|
488
486
|
|
@@ -534,11 +532,6 @@ class Nystroem(BaseTransformer):
|
|
534
532
|
subproject=_SUBPROJECT,
|
535
533
|
custom_tags=dict([("autogen", True)]),
|
536
534
|
)
|
537
|
-
@telemetry.add_stmt_params_to_df(
|
538
|
-
project=_PROJECT,
|
539
|
-
subproject=_SUBPROJECT,
|
540
|
-
custom_tags=dict([("autogen", True)]),
|
541
|
-
)
|
542
535
|
def transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, pd.DataFrame]:
|
543
536
|
"""Apply feature map to X
|
544
537
|
For more details on this function, see [sklearn.kernel_approximation.Nystroem.transform]
|
@@ -597,7 +590,8 @@ class Nystroem(BaseTransformer):
|
|
597
590
|
if False:
|
598
591
|
self.fit(dataset)
|
599
592
|
assert self._sklearn_object is not None
|
600
|
-
|
593
|
+
labels : npt.NDArray[Any] = self._sklearn_object.labels_
|
594
|
+
return labels
|
601
595
|
else:
|
602
596
|
raise NotImplementedError
|
603
597
|
|
@@ -633,6 +627,7 @@ class Nystroem(BaseTransformer):
|
|
633
627
|
output_cols = []
|
634
628
|
|
635
629
|
# Make sure column names are valid snowflake identifiers.
|
630
|
+
assert output_cols is not None # Make MyPy happy
|
636
631
|
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
637
632
|
|
638
633
|
return rv
|
@@ -643,11 +638,6 @@ class Nystroem(BaseTransformer):
|
|
643
638
|
subproject=_SUBPROJECT,
|
644
639
|
custom_tags=dict([("autogen", True)]),
|
645
640
|
)
|
646
|
-
@telemetry.add_stmt_params_to_df(
|
647
|
-
project=_PROJECT,
|
648
|
-
subproject=_SUBPROJECT,
|
649
|
-
custom_tags=dict([("autogen", True)]),
|
650
|
-
)
|
651
641
|
def predict_proba(
|
652
642
|
self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "predict_proba_"
|
653
643
|
) -> Union[DataFrame, pd.DataFrame]:
|
@@ -688,11 +678,6 @@ class Nystroem(BaseTransformer):
|
|
688
678
|
subproject=_SUBPROJECT,
|
689
679
|
custom_tags=dict([("autogen", True)]),
|
690
680
|
)
|
691
|
-
@telemetry.add_stmt_params_to_df(
|
692
|
-
project=_PROJECT,
|
693
|
-
subproject=_SUBPROJECT,
|
694
|
-
custom_tags=dict([("autogen", True)]),
|
695
|
-
)
|
696
681
|
def predict_log_proba(
|
697
682
|
self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "predict_log_proba_"
|
698
683
|
) -> Union[DataFrame, pd.DataFrame]:
|
@@ -729,16 +714,6 @@ class Nystroem(BaseTransformer):
|
|
729
714
|
return output_df
|
730
715
|
|
731
716
|
@available_if(original_estimator_has_callable("decision_function")) # type: ignore[misc]
|
732
|
-
@telemetry.send_api_usage_telemetry(
|
733
|
-
project=_PROJECT,
|
734
|
-
subproject=_SUBPROJECT,
|
735
|
-
custom_tags=dict([("autogen", True)]),
|
736
|
-
)
|
737
|
-
@telemetry.add_stmt_params_to_df(
|
738
|
-
project=_PROJECT,
|
739
|
-
subproject=_SUBPROJECT,
|
740
|
-
custom_tags=dict([("autogen", True)]),
|
741
|
-
)
|
742
717
|
def decision_function(
|
743
718
|
self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "decision_function_"
|
744
719
|
) -> Union[DataFrame, pd.DataFrame]:
|
@@ -837,11 +812,6 @@ class Nystroem(BaseTransformer):
|
|
837
812
|
subproject=_SUBPROJECT,
|
838
813
|
custom_tags=dict([("autogen", True)]),
|
839
814
|
)
|
840
|
-
@telemetry.add_stmt_params_to_df(
|
841
|
-
project=_PROJECT,
|
842
|
-
subproject=_SUBPROJECT,
|
843
|
-
custom_tags=dict([("autogen", True)]),
|
844
|
-
)
|
845
815
|
def kneighbors(
|
846
816
|
self,
|
847
817
|
dataset: Union[DataFrame, pd.DataFrame],
|
@@ -901,9 +871,9 @@ class Nystroem(BaseTransformer):
|
|
901
871
|
# For classifier, the type of predict is the same as the type of label
|
902
872
|
if self._sklearn_object._estimator_type == 'classifier':
|
903
873
|
# label columns is the desired type for output
|
904
|
-
outputs = _infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True)
|
874
|
+
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
905
875
|
# rename the output columns
|
906
|
-
outputs = model_signature_utils.rename_features(outputs, self.output_cols)
|
876
|
+
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
907
877
|
self._model_signature_dict["predict"] = ModelSignature(inputs,
|
908
878
|
([] if self._drop_input_cols else inputs)
|
909
879
|
+ outputs)
|