snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
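The hunks below come from snowflake/ml/modeling/feature_selection/variance_threshold.py (+248 -175); the near-identical +246/-175 and +248/-175 counts across the snowflake/ml/modeling/* files above suggest the same template-level change was applied to every autogenerated wrapper. In summary: the module-level _is_fit_transform_method_enabled() stub is removed, the public fit() (with its telemetry decorator) becomes _fit(), a working fit_transform() built on ModelTrainerBuilder.build_fit_transform() is added, _batch_inference_validate_snowpark() now returns None with dependencies fetched separately via _get_dependencies(), and the new _get_output_column_names()/_align_expected_output_names() helpers reconcile expected inference output column names. A usage sketch follows the diff.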
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
     BatchInferenceKwargsTypedDict,
     ScoreKwargsTypedDict
 )
+from snowflake.ml.model._signatures import utils as model_signature_utils
+from snowflake.ml.model.model_signature import (
+    BaseFeatureSpec,
+    DataType,
+    FeatureSpec,
+    ModelSignature,
+    _infer_signature,
+    _rename_signature_with_snowflake_identifiers,
+)

 from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder

@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
     validate_sklearn_args,
 )

-from snowflake.ml.model.model_signature import (
-    DataType,
-    FeatureSpec,
-    ModelSignature,
-    _infer_signature,
-    _rename_signature_with_snowflake_identifiers,
-    BaseFeatureSpec,
-)
-from snowflake.ml.model._signatures import utils as model_signature_utils
-
 _PROJECT = "ModelDevelopment"
 # Derive subproject from module name by removing "sklearn"
 # and converting module name from underscore to CamelCase
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.feature_selection".repla

 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]

-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class VarianceThreshold(BaseTransformer):
     r"""Feature selector that removes all low-variance features
     For more details on this class, see [sklearn.feature_selection.VarianceThreshold]
@@ -196,12 +189,7 @@ class VarianceThreshold(BaseTransformer):
         )
         return selected_cols

-    @telemetry.send_api_usage_telemetry(
-        project=_PROJECT,
-        subproject=_SUBPROJECT,
-        custom_tags=dict([("autogen", True)]),
-    )
-    def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "VarianceThreshold":
+    def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "VarianceThreshold":
         """Learn empirical variances from X
         For more details on this function, see [sklearn.feature_selection.VarianceThreshold.fit]
         (https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.VarianceThreshold.html#sklearn.feature_selection.VarianceThreshold.fit)
@@ -228,12 +216,14 @@ class VarianceThreshold(BaseTransformer):

         self._snowpark_cols = dataset.select(self.input_cols).columns

-
+        # If we are already in a stored procedure, no need to kick off another one.
         if SNOWML_SPROC_ENV in os.environ:
             statement_params = telemetry.get_function_usage_statement_params(
                 project=_PROJECT,
                 subproject=_SUBPROJECT,
-                function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), VarianceThreshold.__class__.__name__),
+                function_name=telemetry.get_statement_params_full_func_name(
+                    inspect.currentframe(), VarianceThreshold.__class__.__name__
+                ),
                 api_calls=[Session.call],
                 custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
             )
@@ -254,27 +244,24 @@ class VarianceThreshold(BaseTransformer):
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
-        self._get_model_signatures(dataset)
+        self._generate_model_signatures(dataset)
         return self

     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.

         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error

-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -292,9 +279,7 @@ class VarianceThreshold(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+

     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -328,7 +313,9 @@ class VarianceThreshold(BaseTransformer):
             # when it is classifier, infer the datatype from label columns
             if expected_type_inferred == "" and 'predict' in self.model_signatures:
                 # Batch inference takes a single expected output column type. Use the first columns type for now.
-                label_cols_signatures = [row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols]
+                label_cols_signatures = [
+                    row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
+                ]
                 if len(label_cols_signatures) == 0:
                     error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
                     raise exceptions.SnowflakeMLException(
@@ -336,25 +323,23 @@ class VarianceThreshold(BaseTransformer):
                         original_exception=ValueError(error_str),
                     )

-                expected_type_inferred = convert_sp_to_sf_type(
-                    label_cols_signatures[0].as_snowpark_type()
-                )
+                expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())

-            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
-            assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

             transform_kwargs = dict(
-                session = dataset._session,
-                dependencies = self._deps,
-                drop_input_cols = self._drop_input_cols,
-                expected_output_cols_type = expected_type_inferred,
+                session=dataset._session,
+                dependencies=self._deps,
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_type=expected_type_inferred,
             )

         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)

         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -396,7 +381,7 @@ class VarianceThreshold(BaseTransformer):
             Transformed dataset.
         """
         super()._check_dataset_type(dataset)
-        inference_method="transform"
+        inference_method = "transform"

         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
@@ -426,24 +411,19 @@ class VarianceThreshold(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])

-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()

             transform_kwargs = dict(
-                session = dataset._session,
-                dependencies = self._deps,
-                drop_input_cols = self._drop_input_cols,
-                expected_output_cols_type = expected_dtype,
+                session=dataset._session,
+                dependencies=self._deps,
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_type=expected_dtype,
             )

         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)

         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -462,7 +442,11 @@ class VarianceThreshold(BaseTransformer):
         return output_df

     @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_predict_",) -> Union[DataFrame, pd.DataFrame]:
+    def fit_predict(
+        self,
+        dataset: Union[DataFrame, pd.DataFrame],
+        output_cols_prefix: str = "fit_predict_",
+    ) -> Union[DataFrame, pd.DataFrame]:
         """ Method not supported for this class.


@@ -487,22 +471,106 @@ class VarianceThreshold(BaseTransformer):
         )
         output_result, fitted_estimator = model_trainer.train_fit_predict(
             drop_input_cols=self._drop_input_cols,
-            expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
+            expected_output_cols_list=(
+                self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
+            ),
         )
         self._sklearn_object = fitted_estimator
         self._is_fitted = True
         return output_result

+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Fit to data, then transform it
+        For more details on this function, see [sklearn.feature_selection.VarianceThreshold.fit_transform]
+        (https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.VarianceThreshold.html#sklearn.feature_selection.VarianceThreshold.fit_transform)
+
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.

-    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
-    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
-        """
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.fit(dataset)
-        assert self._sklearn_object is not None
-        return self._sklearn_object.embedding_
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
+
+
+    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
+        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
+        """
+        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
+        # The following condition is introduced for kneighbors methods, and not used in other methods
+        if output_cols:
+            output_cols = [
+                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
+                for c in output_cols
+            ]
+        elif getattr(self._sklearn_object, "classes_", None) is None:
+            output_cols = [output_cols_prefix]
+        elif self._sklearn_object is not None:
+            classes = self._sklearn_object.classes_
+            if isinstance(classes, numpy.ndarray):
+                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
+            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
+                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
+                output_cols = []
+                for i, cl in enumerate(classes):
+                    # For binary classification, there is only one output column for each class
+                    # ndarray as the two classes are complementary.
+                    if len(cl) == 2:
+                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
+                    else:
+                        output_cols.extend([
+                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
+                        ])
+        else:
+            output_cols = []
+
+        # Make sure column names are valid snowflake identifiers.
+        assert output_cols is not None  # Make MyPy happy
+        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
+
+        return rv
+
+    def _align_expected_output_names(
+        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
+    ) -> List[str]:
+        # in case the inferred output column names dimension is different
+        # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
+        output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
+        output_df_columns = list(output_df_pd.columns)
+        output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
+        if self.sample_weight_col:
+            output_df_columns_set -= set(self.sample_weight_col)
+        # if the dimension of inferred output column names is correct; use it
+        if len(expected_output_cols_list) == len(output_df_columns_set):
+            return expected_output_cols_list
+        # otherwise, use the sklearn estimator's output
+        else:
+            return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))


     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -534,24 +602,26 @@ class VarianceThreshold(BaseTransformer):
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()

+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
+
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
-            assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                drop_input_cols = self._drop_input_cols,
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )

         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)

         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -563,7 +633,7 @@ class VarianceThreshold(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=self._get_output_column_names(output_cols_prefix),
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -593,29 +663,30 @@ class VarianceThreshold(BaseTransformer):
             Output dataset with log probability of the sample for each class in the model.
         """
         super()._check_dataset_type(dataset)
-        inference_method="predict_log_proba"
+        inference_method = "predict_log_proba"
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)

         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
-            assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                drop_input_cols = self._drop_input_cols,
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)

         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -628,7 +699,7 @@ class VarianceThreshold(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=self._get_output_column_names(output_cols_prefix),
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -654,30 +725,32 @@ class VarianceThreshold(BaseTransformer):
             Output dataset with results of the decision function for the samples in input dataset.
         """
         super()._check_dataset_type(dataset)
-        inference_method="decision_function"
+        inference_method = "decision_function"

         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()

+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
+
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
-            assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                drop_input_cols = self._drop_input_cols,
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )

         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)

         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -690,7 +763,7 @@ class VarianceThreshold(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=self._get_output_column_names(output_cols_prefix),
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -719,17 +792,17 @@ class VarianceThreshold(BaseTransformer):
             Output dataset with probability of the sample for each class in the model.
         """
         super()._check_dataset_type(dataset)
-        inference_method="score_samples"
+        inference_method = "score_samples"

         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()

+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
+
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -737,6 +810,9 @@ class VarianceThreshold(BaseTransformer):
                 drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )

         elif isinstance(dataset, pd.DataFrame):
             transform_kwargs = dict(
@@ -755,7 +831,7 @@ class VarianceThreshold(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=self._get_output_column_names(output_cols_prefix),
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -788,17 +864,15 @@ class VarianceThreshold(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._deps,
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -863,11 +937,8 @@ class VarianceThreshold(BaseTransformer):

         if isinstance(dataset, DataFrame):

-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
@@ -900,50 +971,84 @@ class VarianceThreshold(BaseTransformer):
         )
         return output_df

+
+    def to_sklearn(self) -> Any:
+        """Get sklearn.feature_selection.VarianceThreshold object.
+        """
+        if self._sklearn_object is None:
+            self._sklearn_object = self._create_sklearn_object()
+        return self._sklearn_object
+
+    def to_xgboost(self) -> Any:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.METHOD_NOT_ALLOWED,
+            original_exception=AttributeError(
+                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
+                    "to_xgboost()",
+                    "to_sklearn()"
+                )
+            ),
+        )

-    def _get_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
+    def to_lightgbm(self) -> Any:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.METHOD_NOT_ALLOWED,
+            original_exception=AttributeError(
+                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
+                    "to_lightgbm()",
+                    "to_sklearn()"
+                )
+            ),
+        )
+
+    def _get_dependencies(self) -> List[str]:
+        return self._deps
+
+
+    def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         self._model_signature_dict = dict()

         PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]

-        inputs = list(_infer_signature(dataset[self.input_cols], "input"))
+        inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
         outputs: List[BaseFeatureSpec] = []
         if hasattr(self, "predict"):
             # keep mypy happy
-            assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
+            assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
             # For classifier, the type of predict is the same as the type of label
-            if self._sklearn_object._estimator_type == 'classifier':
-                # label columns is the desired type for output
+            if self._sklearn_object._estimator_type == "classifier":
+                # label columns is the desired type for output
                 outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
                 # rename the output columns
                 outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
-                self._model_signature_dict["predict"] = ModelSignature(inputs,
-                                                                       ([] if self._drop_input_cols else inputs)
-                                                                       + outputs)
+                self._model_signature_dict["predict"] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
             # For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
             # For outlier models, returns -1 for outliers and 1 for inliers.
-            # Clusterer returns int64 cluster labels.
+            # Clusterer returns int64 cluster labels.
             elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
                 outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
-                self._model_signature_dict["predict"] = ModelSignature(inputs,
-                                                                       ([] if self._drop_input_cols else inputs)
-                                                                       + outputs)
-
+                self._model_signature_dict["predict"] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
+
             # For regressor, the type of predict is float64
-            elif self._sklearn_object._estimator_type == 'regressor':
+            elif self._sklearn_object._estimator_type == "regressor":
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-                self._model_signature_dict["predict"] = ModelSignature(inputs,
-                                                                       ([] if self._drop_input_cols else inputs)
-                                                                       + outputs)
-
+                self._model_signature_dict["predict"] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
+
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs,
-                                                                       ([] if self._drop_input_cols else inputs)
-                                                                       + outputs)
+                self._model_signature_dict[prob_func] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )

         # Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
         items = list(self._model_signature_dict.items())
@@ -956,10 +1061,10 @@ class VarianceThreshold(BaseTransformer):
         """Returns model signature of current class.

         Raises:
-            exceptions.SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
+            SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred

         Returns:
-            Dict[str, ModelSignature]: each method and its input output signature
+            Dict with each method and its input output signature
         """
         if self._model_signature_dict is None:
             raise exceptions.SnowflakeMLException(
@@ -967,35 +1072,3 @@ class VarianceThreshold(BaseTransformer):
                 original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
             )
         return self._model_signature_dict
-
-    def to_sklearn(self) -> Any:
-        """Get sklearn.feature_selection.VarianceThreshold object.
-        """
-        if self._sklearn_object is None:
-            self._sklearn_object = self._create_sklearn_object()
-        return self._sklearn_object
-
-    def to_xgboost(self) -> Any:
-        raise exceptions.SnowflakeMLException(
-            error_code=error_codes.METHOD_NOT_ALLOWED,
-            original_exception=AttributeError(
-                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
-                    "to_xgboost()",
-                    "to_sklearn()"
-                )
-            ),
-        )
-
-    def to_lightgbm(self) -> Any:
-        raise exceptions.SnowflakeMLException(
-            error_code=error_codes.METHOD_NOT_ALLOWED,
-            original_exception=AttributeError(
-                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
-                    "to_lightgbm()",
-                    "to_sklearn()"
-                )
-            ),
-        )
-
-    def _get_dependencies(self) -> List[str]:
-        return self._deps