snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
|
|
33
33
|
BatchInferenceKwargsTypedDict,
|
34
34
|
ScoreKwargsTypedDict
|
35
35
|
)
|
36
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
37
|
+
from snowflake.ml.model.model_signature import (
|
38
|
+
BaseFeatureSpec,
|
39
|
+
DataType,
|
40
|
+
FeatureSpec,
|
41
|
+
ModelSignature,
|
42
|
+
_infer_signature,
|
43
|
+
_rename_signature_with_snowflake_identifiers,
|
44
|
+
)
|
36
45
|
|
37
46
|
from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
|
38
47
|
|
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
|
|
43
52
|
validate_sklearn_args,
|
44
53
|
)
|
45
54
|
|
46
|
-
from snowflake.ml.model.model_signature import (
|
47
|
-
DataType,
|
48
|
-
FeatureSpec,
|
49
|
-
ModelSignature,
|
50
|
-
_infer_signature,
|
51
|
-
_rename_signature_with_snowflake_identifiers,
|
52
|
-
BaseFeatureSpec,
|
53
|
-
)
|
54
|
-
from snowflake.ml.model._signatures import utils as model_signature_utils
|
55
|
-
|
56
55
|
_PROJECT = "ModelDevelopment"
|
57
56
|
# Derive subproject from module name by removing "sklearn"
|
58
57
|
# and converting module name from underscore to CamelCase
|
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
61
60
|
|
62
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
63
62
|
|
64
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
65
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
66
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
67
|
-
return check
|
68
|
-
|
69
|
-
|
70
63
|
class BayesianRidge(BaseTransformer):
|
71
64
|
r"""Bayesian ridge regression
|
72
65
|
For more details on this class, see [sklearn.linear_model.BayesianRidge]
|
@@ -269,12 +262,7 @@ class BayesianRidge(BaseTransformer):
|
|
269
262
|
)
|
270
263
|
return selected_cols
|
271
264
|
|
272
|
-
|
273
|
-
project=_PROJECT,
|
274
|
-
subproject=_SUBPROJECT,
|
275
|
-
custom_tags=dict([("autogen", True)]),
|
276
|
-
)
|
277
|
-
def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "BayesianRidge":
|
265
|
+
def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "BayesianRidge":
|
278
266
|
"""Fit the model
|
279
267
|
For more details on this function, see [sklearn.linear_model.BayesianRidge.fit]
|
280
268
|
(https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.BayesianRidge.html#sklearn.linear_model.BayesianRidge.fit)
|
@@ -301,12 +289,14 @@ class BayesianRidge(BaseTransformer):
|
|
301
289
|
|
302
290
|
self._snowpark_cols = dataset.select(self.input_cols).columns
|
303
291
|
|
304
|
-
|
292
|
+
# If we are already in a stored procedure, no need to kick off another one.
|
305
293
|
if SNOWML_SPROC_ENV in os.environ:
|
306
294
|
statement_params = telemetry.get_function_usage_statement_params(
|
307
295
|
project=_PROJECT,
|
308
296
|
subproject=_SUBPROJECT,
|
309
|
-
function_name=telemetry.get_statement_params_full_func_name(
|
297
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
298
|
+
inspect.currentframe(), BayesianRidge.__class__.__name__
|
299
|
+
),
|
310
300
|
api_calls=[Session.call],
|
311
301
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
312
302
|
)
|
@@ -327,27 +317,24 @@ class BayesianRidge(BaseTransformer):
|
|
327
317
|
)
|
328
318
|
self._sklearn_object = model_trainer.train()
|
329
319
|
self._is_fitted = True
|
330
|
-
self.
|
320
|
+
self._generate_model_signatures(dataset)
|
331
321
|
return self
|
332
322
|
|
333
323
|
def _batch_inference_validate_snowpark(
|
334
324
|
self,
|
335
325
|
dataset: DataFrame,
|
336
326
|
inference_method: str,
|
337
|
-
) ->
|
338
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
339
|
-
return the available package that exists in the snowflake anaconda channel
|
327
|
+
) -> None:
|
328
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
340
329
|
|
341
330
|
Args:
|
342
331
|
dataset: snowpark dataframe
|
343
332
|
inference_method: the inference method such as predict, score...
|
344
|
-
|
333
|
+
|
345
334
|
Raises:
|
346
335
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
347
336
|
SnowflakeMLException: If the session is None, raise error
|
348
337
|
|
349
|
-
Returns:
|
350
|
-
A list of available package that exists in the snowflake anaconda channel
|
351
338
|
"""
|
352
339
|
if not self._is_fitted:
|
353
340
|
raise exceptions.SnowflakeMLException(
|
@@ -365,9 +352,7 @@ class BayesianRidge(BaseTransformer):
|
|
365
352
|
"Session must not specified for snowpark dataset."
|
366
353
|
),
|
367
354
|
)
|
368
|
-
|
369
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
370
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
355
|
+
|
371
356
|
|
372
357
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
373
358
|
@telemetry.send_api_usage_telemetry(
|
@@ -403,7 +388,9 @@ class BayesianRidge(BaseTransformer):
|
|
403
388
|
# when it is classifier, infer the datatype from label columns
|
404
389
|
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
405
390
|
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
406
|
-
label_cols_signatures = [
|
391
|
+
label_cols_signatures = [
|
392
|
+
row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
|
393
|
+
]
|
407
394
|
if len(label_cols_signatures) == 0:
|
408
395
|
error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
|
409
396
|
raise exceptions.SnowflakeMLException(
|
@@ -411,25 +398,23 @@ class BayesianRidge(BaseTransformer):
|
|
411
398
|
original_exception=ValueError(error_str),
|
412
399
|
)
|
413
400
|
|
414
|
-
expected_type_inferred = convert_sp_to_sf_type(
|
415
|
-
label_cols_signatures[0].as_snowpark_type()
|
416
|
-
)
|
401
|
+
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
417
402
|
|
418
|
-
self.
|
419
|
-
|
403
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
404
|
+
self._deps = self._get_dependencies()
|
405
|
+
assert isinstance(
|
406
|
+
dataset._session, Session
|
407
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
420
408
|
|
421
409
|
transform_kwargs = dict(
|
422
|
-
session
|
423
|
-
dependencies
|
424
|
-
drop_input_cols
|
425
|
-
expected_output_cols_type
|
410
|
+
session=dataset._session,
|
411
|
+
dependencies=self._deps,
|
412
|
+
drop_input_cols=self._drop_input_cols,
|
413
|
+
expected_output_cols_type=expected_type_inferred,
|
426
414
|
)
|
427
415
|
|
428
416
|
elif isinstance(dataset, pd.DataFrame):
|
429
|
-
transform_kwargs = dict(
|
430
|
-
snowpark_input_cols = self._snowpark_cols,
|
431
|
-
drop_input_cols = self._drop_input_cols
|
432
|
-
)
|
417
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
433
418
|
|
434
419
|
transform_handlers = ModelTransformerBuilder.build(
|
435
420
|
dataset=dataset,
|
@@ -469,7 +454,7 @@ class BayesianRidge(BaseTransformer):
|
|
469
454
|
Transformed dataset.
|
470
455
|
"""
|
471
456
|
super()._check_dataset_type(dataset)
|
472
|
-
inference_method="transform"
|
457
|
+
inference_method = "transform"
|
473
458
|
|
474
459
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
475
460
|
# are specific to the type of dataset used.
|
@@ -499,24 +484,19 @@ class BayesianRidge(BaseTransformer):
|
|
499
484
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
500
485
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
501
486
|
|
502
|
-
self.
|
503
|
-
|
504
|
-
inference_method=inference_method,
|
505
|
-
)
|
487
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
488
|
+
self._deps = self._get_dependencies()
|
506
489
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
507
490
|
|
508
491
|
transform_kwargs = dict(
|
509
|
-
session
|
510
|
-
dependencies
|
511
|
-
drop_input_cols
|
512
|
-
expected_output_cols_type
|
492
|
+
session=dataset._session,
|
493
|
+
dependencies=self._deps,
|
494
|
+
drop_input_cols=self._drop_input_cols,
|
495
|
+
expected_output_cols_type=expected_dtype,
|
513
496
|
)
|
514
497
|
|
515
498
|
elif isinstance(dataset, pd.DataFrame):
|
516
|
-
transform_kwargs = dict(
|
517
|
-
snowpark_input_cols = self._snowpark_cols,
|
518
|
-
drop_input_cols = self._drop_input_cols
|
519
|
-
)
|
499
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
520
500
|
|
521
501
|
transform_handlers = ModelTransformerBuilder.build(
|
522
502
|
dataset=dataset,
|
@@ -535,7 +515,11 @@ class BayesianRidge(BaseTransformer):
|
|
535
515
|
return output_df
|
536
516
|
|
537
517
|
@available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
|
538
|
-
def fit_predict(
|
518
|
+
def fit_predict(
|
519
|
+
self,
|
520
|
+
dataset: Union[DataFrame, pd.DataFrame],
|
521
|
+
output_cols_prefix: str = "fit_predict_",
|
522
|
+
) -> Union[DataFrame, pd.DataFrame]:
|
539
523
|
""" Method not supported for this class.
|
540
524
|
|
541
525
|
|
@@ -560,22 +544,104 @@ class BayesianRidge(BaseTransformer):
|
|
560
544
|
)
|
561
545
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
562
546
|
drop_input_cols=self._drop_input_cols,
|
563
|
-
expected_output_cols_list=
|
547
|
+
expected_output_cols_list=(
|
548
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
549
|
+
),
|
564
550
|
)
|
565
551
|
self._sklearn_object = fitted_estimator
|
566
552
|
self._is_fitted = True
|
567
553
|
return output_result
|
568
554
|
|
555
|
+
|
556
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
557
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
558
|
+
""" Method not supported for this class.
|
559
|
+
|
569
560
|
|
570
|
-
|
571
|
-
|
572
|
-
|
561
|
+
Raises:
|
562
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
563
|
+
|
564
|
+
Args:
|
565
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
566
|
+
Snowpark or Pandas DataFrame.
|
567
|
+
output_cols_prefix: Prefix for the response columns
|
573
568
|
Returns:
|
574
569
|
Transformed dataset.
|
575
570
|
"""
|
576
|
-
self.
|
577
|
-
|
578
|
-
|
571
|
+
self._infer_input_output_cols(dataset)
|
572
|
+
super()._check_dataset_type(dataset)
|
573
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
574
|
+
estimator=self._sklearn_object,
|
575
|
+
dataset=dataset,
|
576
|
+
input_cols=self.input_cols,
|
577
|
+
label_cols=self.label_cols,
|
578
|
+
sample_weight_col=self.sample_weight_col,
|
579
|
+
autogenerated=self._autogenerated,
|
580
|
+
subproject=_SUBPROJECT,
|
581
|
+
)
|
582
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
583
|
+
drop_input_cols=self._drop_input_cols,
|
584
|
+
expected_output_cols_list=self.output_cols,
|
585
|
+
)
|
586
|
+
self._sklearn_object = fitted_estimator
|
587
|
+
self._is_fitted = True
|
588
|
+
return output_result
|
589
|
+
|
590
|
+
|
591
|
+
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
592
|
+
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
593
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
594
|
+
"""
|
595
|
+
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
596
|
+
# The following condition is introduced for kneighbors methods, and not used in other methods
|
597
|
+
if output_cols:
|
598
|
+
output_cols = [
|
599
|
+
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
600
|
+
for c in output_cols
|
601
|
+
]
|
602
|
+
elif getattr(self._sklearn_object, "classes_", None) is None:
|
603
|
+
output_cols = [output_cols_prefix]
|
604
|
+
elif self._sklearn_object is not None:
|
605
|
+
classes = self._sklearn_object.classes_
|
606
|
+
if isinstance(classes, numpy.ndarray):
|
607
|
+
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
608
|
+
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
609
|
+
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
610
|
+
output_cols = []
|
611
|
+
for i, cl in enumerate(classes):
|
612
|
+
# For binary classification, there is only one output column for each class
|
613
|
+
# ndarray as the two classes are complementary.
|
614
|
+
if len(cl) == 2:
|
615
|
+
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
616
|
+
else:
|
617
|
+
output_cols.extend([
|
618
|
+
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
619
|
+
])
|
620
|
+
else:
|
621
|
+
output_cols = []
|
622
|
+
|
623
|
+
# Make sure column names are valid snowflake identifiers.
|
624
|
+
assert output_cols is not None # Make MyPy happy
|
625
|
+
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
626
|
+
|
627
|
+
return rv
|
628
|
+
|
629
|
+
def _align_expected_output_names(
|
630
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
631
|
+
) -> List[str]:
|
632
|
+
# in case the inferred output column names dimension is different
|
633
|
+
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
634
|
+
output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
|
635
|
+
output_df_columns = list(output_df_pd.columns)
|
636
|
+
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
637
|
+
if self.sample_weight_col:
|
638
|
+
output_df_columns_set -= set(self.sample_weight_col)
|
639
|
+
# if the dimension of inferred output column names is correct; use it
|
640
|
+
if len(expected_output_cols_list) == len(output_df_columns_set):
|
641
|
+
return expected_output_cols_list
|
642
|
+
# otherwise, use the sklearn estimator's output
|
643
|
+
else:
|
644
|
+
return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
579
645
|
|
580
646
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
581
647
|
@telemetry.send_api_usage_telemetry(
|
@@ -607,24 +673,26 @@ class BayesianRidge(BaseTransformer):
|
|
607
673
|
# are specific to the type of dataset used.
|
608
674
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
609
675
|
|
676
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
677
|
+
|
610
678
|
if isinstance(dataset, DataFrame):
|
611
|
-
self.
|
612
|
-
|
613
|
-
|
614
|
-
|
615
|
-
|
679
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
680
|
+
self._deps = self._get_dependencies()
|
681
|
+
assert isinstance(
|
682
|
+
dataset._session, Session
|
683
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
616
684
|
transform_kwargs = dict(
|
617
685
|
session=dataset._session,
|
618
686
|
dependencies=self._deps,
|
619
|
-
drop_input_cols
|
687
|
+
drop_input_cols=self._drop_input_cols,
|
620
688
|
expected_output_cols_type="float",
|
621
689
|
)
|
690
|
+
expected_output_cols = self._align_expected_output_names(
|
691
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
692
|
+
)
|
622
693
|
|
623
694
|
elif isinstance(dataset, pd.DataFrame):
|
624
|
-
transform_kwargs = dict(
|
625
|
-
snowpark_input_cols = self._snowpark_cols,
|
626
|
-
drop_input_cols = self._drop_input_cols
|
627
|
-
)
|
695
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
628
696
|
|
629
697
|
transform_handlers = ModelTransformerBuilder.build(
|
630
698
|
dataset=dataset,
|
@@ -636,7 +704,7 @@ class BayesianRidge(BaseTransformer):
|
|
636
704
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
637
705
|
inference_method=inference_method,
|
638
706
|
input_cols=self.input_cols,
|
639
|
-
expected_output_cols=
|
707
|
+
expected_output_cols=expected_output_cols,
|
640
708
|
**transform_kwargs
|
641
709
|
)
|
642
710
|
return output_df
|
@@ -666,29 +734,30 @@ class BayesianRidge(BaseTransformer):
|
|
666
734
|
Output dataset with log probability of the sample for each class in the model.
|
667
735
|
"""
|
668
736
|
super()._check_dataset_type(dataset)
|
669
|
-
inference_method="predict_log_proba"
|
737
|
+
inference_method = "predict_log_proba"
|
738
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
670
739
|
|
671
740
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
672
741
|
# are specific to the type of dataset used.
|
673
742
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
674
743
|
|
675
744
|
if isinstance(dataset, DataFrame):
|
676
|
-
self.
|
677
|
-
|
678
|
-
|
679
|
-
|
680
|
-
|
745
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
746
|
+
self._deps = self._get_dependencies()
|
747
|
+
assert isinstance(
|
748
|
+
dataset._session, Session
|
749
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
681
750
|
transform_kwargs = dict(
|
682
751
|
session=dataset._session,
|
683
752
|
dependencies=self._deps,
|
684
|
-
drop_input_cols
|
753
|
+
drop_input_cols=self._drop_input_cols,
|
685
754
|
expected_output_cols_type="float",
|
686
755
|
)
|
756
|
+
expected_output_cols = self._align_expected_output_names(
|
757
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
758
|
+
)
|
687
759
|
elif isinstance(dataset, pd.DataFrame):
|
688
|
-
transform_kwargs = dict(
|
689
|
-
snowpark_input_cols = self._snowpark_cols,
|
690
|
-
drop_input_cols = self._drop_input_cols
|
691
|
-
)
|
760
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
692
761
|
|
693
762
|
transform_handlers = ModelTransformerBuilder.build(
|
694
763
|
dataset=dataset,
|
@@ -701,7 +770,7 @@ class BayesianRidge(BaseTransformer):
|
|
701
770
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
702
771
|
inference_method=inference_method,
|
703
772
|
input_cols=self.input_cols,
|
704
|
-
expected_output_cols=
|
773
|
+
expected_output_cols=expected_output_cols,
|
705
774
|
**transform_kwargs
|
706
775
|
)
|
707
776
|
return output_df
|
@@ -727,30 +796,32 @@ class BayesianRidge(BaseTransformer):
|
|
727
796
|
Output dataset with results of the decision function for the samples in input dataset.
|
728
797
|
"""
|
729
798
|
super()._check_dataset_type(dataset)
|
730
|
-
inference_method="decision_function"
|
799
|
+
inference_method = "decision_function"
|
731
800
|
|
732
801
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
733
802
|
# are specific to the type of dataset used.
|
734
803
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
735
804
|
|
805
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
806
|
+
|
736
807
|
if isinstance(dataset, DataFrame):
|
737
|
-
self.
|
738
|
-
|
739
|
-
|
740
|
-
|
741
|
-
|
808
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
809
|
+
self._deps = self._get_dependencies()
|
810
|
+
assert isinstance(
|
811
|
+
dataset._session, Session
|
812
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
742
813
|
transform_kwargs = dict(
|
743
814
|
session=dataset._session,
|
744
815
|
dependencies=self._deps,
|
745
|
-
drop_input_cols
|
816
|
+
drop_input_cols=self._drop_input_cols,
|
746
817
|
expected_output_cols_type="float",
|
747
818
|
)
|
819
|
+
expected_output_cols = self._align_expected_output_names(
|
820
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
821
|
+
)
|
748
822
|
|
749
823
|
elif isinstance(dataset, pd.DataFrame):
|
750
|
-
transform_kwargs = dict(
|
751
|
-
snowpark_input_cols = self._snowpark_cols,
|
752
|
-
drop_input_cols = self._drop_input_cols
|
753
|
-
)
|
824
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
754
825
|
|
755
826
|
transform_handlers = ModelTransformerBuilder.build(
|
756
827
|
dataset=dataset,
|
@@ -763,7 +834,7 @@ class BayesianRidge(BaseTransformer):
|
|
763
834
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
764
835
|
inference_method=inference_method,
|
765
836
|
input_cols=self.input_cols,
|
766
|
-
expected_output_cols=
|
837
|
+
expected_output_cols=expected_output_cols,
|
767
838
|
**transform_kwargs
|
768
839
|
)
|
769
840
|
return output_df
|
@@ -792,17 +863,17 @@ class BayesianRidge(BaseTransformer):
|
|
792
863
|
Output dataset with probability of the sample for each class in the model.
|
793
864
|
"""
|
794
865
|
super()._check_dataset_type(dataset)
|
795
|
-
inference_method="score_samples"
|
866
|
+
inference_method = "score_samples"
|
796
867
|
|
797
868
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
798
869
|
# are specific to the type of dataset used.
|
799
870
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
800
871
|
|
872
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
873
|
+
|
801
874
|
if isinstance(dataset, DataFrame):
|
802
|
-
self.
|
803
|
-
|
804
|
-
inference_method=inference_method,
|
805
|
-
)
|
875
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
876
|
+
self._deps = self._get_dependencies()
|
806
877
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
807
878
|
transform_kwargs = dict(
|
808
879
|
session=dataset._session,
|
@@ -810,6 +881,9 @@ class BayesianRidge(BaseTransformer):
|
|
810
881
|
drop_input_cols = self._drop_input_cols,
|
811
882
|
expected_output_cols_type="float",
|
812
883
|
)
|
884
|
+
expected_output_cols = self._align_expected_output_names(
|
885
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
886
|
+
)
|
813
887
|
|
814
888
|
elif isinstance(dataset, pd.DataFrame):
|
815
889
|
transform_kwargs = dict(
|
@@ -828,7 +902,7 @@ class BayesianRidge(BaseTransformer):
|
|
828
902
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
829
903
|
inference_method=inference_method,
|
830
904
|
input_cols=self.input_cols,
|
831
|
-
expected_output_cols=
|
905
|
+
expected_output_cols=expected_output_cols,
|
832
906
|
**transform_kwargs
|
833
907
|
)
|
834
908
|
return output_df
|
@@ -863,17 +937,15 @@ class BayesianRidge(BaseTransformer):
|
|
863
937
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
864
938
|
|
865
939
|
if isinstance(dataset, DataFrame):
|
866
|
-
self.
|
867
|
-
|
868
|
-
inference_method="score",
|
869
|
-
)
|
940
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
941
|
+
self._deps = self._get_dependencies()
|
870
942
|
selected_cols = self._get_active_columns()
|
871
943
|
if len(selected_cols) > 0:
|
872
944
|
dataset = dataset.select(selected_cols)
|
873
945
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
874
946
|
transform_kwargs = dict(
|
875
947
|
session=dataset._session,
|
876
|
-
dependencies=
|
948
|
+
dependencies=self._deps,
|
877
949
|
score_sproc_imports=['sklearn'],
|
878
950
|
)
|
879
951
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -938,11 +1010,8 @@ class BayesianRidge(BaseTransformer):
|
|
938
1010
|
|
939
1011
|
if isinstance(dataset, DataFrame):
|
940
1012
|
|
941
|
-
self.
|
942
|
-
|
943
|
-
inference_method=inference_method,
|
944
|
-
|
945
|
-
)
|
1013
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1014
|
+
self._deps = self._get_dependencies()
|
946
1015
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
947
1016
|
transform_kwargs = dict(
|
948
1017
|
session = dataset._session,
|
@@ -975,50 +1044,84 @@ class BayesianRidge(BaseTransformer):
|
|
975
1044
|
)
|
976
1045
|
return output_df
|
977
1046
|
|
1047
|
+
|
1048
|
+
|
1049
|
+
def to_sklearn(self) -> Any:
|
1050
|
+
"""Get sklearn.linear_model.BayesianRidge object.
|
1051
|
+
"""
|
1052
|
+
if self._sklearn_object is None:
|
1053
|
+
self._sklearn_object = self._create_sklearn_object()
|
1054
|
+
return self._sklearn_object
|
1055
|
+
|
1056
|
+
def to_xgboost(self) -> Any:
|
1057
|
+
raise exceptions.SnowflakeMLException(
|
1058
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1059
|
+
original_exception=AttributeError(
|
1060
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1061
|
+
"to_xgboost()",
|
1062
|
+
"to_sklearn()"
|
1063
|
+
)
|
1064
|
+
),
|
1065
|
+
)
|
1066
|
+
|
1067
|
+
def to_lightgbm(self) -> Any:
|
1068
|
+
raise exceptions.SnowflakeMLException(
|
1069
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1070
|
+
original_exception=AttributeError(
|
1071
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1072
|
+
"to_lightgbm()",
|
1073
|
+
"to_sklearn()"
|
1074
|
+
)
|
1075
|
+
),
|
1076
|
+
)
|
1077
|
+
|
1078
|
+
def _get_dependencies(self) -> List[str]:
|
1079
|
+
return self._deps
|
1080
|
+
|
978
1081
|
|
979
|
-
def
|
1082
|
+
def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
980
1083
|
self._model_signature_dict = dict()
|
981
1084
|
|
982
1085
|
PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
|
983
1086
|
|
984
|
-
inputs = list(_infer_signature(dataset[self.input_cols], "input"))
|
1087
|
+
inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
|
985
1088
|
outputs: List[BaseFeatureSpec] = []
|
986
1089
|
if hasattr(self, "predict"):
|
987
1090
|
# keep mypy happy
|
988
|
-
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1091
|
+
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
989
1092
|
# For classifier, the type of predict is the same as the type of label
|
990
|
-
if self._sklearn_object._estimator_type ==
|
991
|
-
|
1093
|
+
if self._sklearn_object._estimator_type == "classifier":
|
1094
|
+
# label columns is the desired type for output
|
992
1095
|
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
993
1096
|
# rename the output columns
|
994
1097
|
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
995
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
996
|
-
|
997
|
-
|
1098
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1099
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1100
|
+
)
|
998
1101
|
# For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
|
999
1102
|
# For outlier models, returns -1 for outliers and 1 for inliers.
|
1000
|
-
# Clusterer returns int64 cluster labels.
|
1103
|
+
# Clusterer returns int64 cluster labels.
|
1001
1104
|
elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
|
1002
1105
|
outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
|
1003
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1004
|
-
|
1005
|
-
|
1006
|
-
|
1106
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1107
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1108
|
+
)
|
1109
|
+
|
1007
1110
|
# For regressor, the type of predict is float64
|
1008
|
-
elif self._sklearn_object._estimator_type ==
|
1111
|
+
elif self._sklearn_object._estimator_type == "regressor":
|
1009
1112
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
1010
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1011
|
-
|
1012
|
-
|
1013
|
-
|
1113
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1114
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1115
|
+
)
|
1116
|
+
|
1014
1117
|
for prob_func in PROB_FUNCTIONS:
|
1015
1118
|
if hasattr(self, prob_func):
|
1016
1119
|
output_cols_prefix: str = f"{prob_func}_"
|
1017
1120
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
1018
1121
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
1019
|
-
self._model_signature_dict[prob_func] = ModelSignature(
|
1020
|
-
|
1021
|
-
|
1122
|
+
self._model_signature_dict[prob_func] = ModelSignature(
|
1123
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1124
|
+
)
|
1022
1125
|
|
1023
1126
|
# Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
|
1024
1127
|
items = list(self._model_signature_dict.items())
|
@@ -1031,10 +1134,10 @@ class BayesianRidge(BaseTransformer):
|
|
1031
1134
|
"""Returns model signature of current class.
|
1032
1135
|
|
1033
1136
|
Raises:
|
1034
|
-
|
1137
|
+
SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
|
1035
1138
|
|
1036
1139
|
Returns:
|
1037
|
-
Dict
|
1140
|
+
Dict with each method and its input output signature
|
1038
1141
|
"""
|
1039
1142
|
if self._model_signature_dict is None:
|
1040
1143
|
raise exceptions.SnowflakeMLException(
|
@@ -1042,35 +1145,3 @@ class BayesianRidge(BaseTransformer):
|
|
1042
1145
|
original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
|
1043
1146
|
)
|
1044
1147
|
return self._model_signature_dict
|
1045
|
-
|
1046
|
-
def to_sklearn(self) -> Any:
|
1047
|
-
"""Get sklearn.linear_model.BayesianRidge object.
|
1048
|
-
"""
|
1049
|
-
if self._sklearn_object is None:
|
1050
|
-
self._sklearn_object = self._create_sklearn_object()
|
1051
|
-
return self._sklearn_object
|
1052
|
-
|
1053
|
-
def to_xgboost(self) -> Any:
|
1054
|
-
raise exceptions.SnowflakeMLException(
|
1055
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1056
|
-
original_exception=AttributeError(
|
1057
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1058
|
-
"to_xgboost()",
|
1059
|
-
"to_sklearn()"
|
1060
|
-
)
|
1061
|
-
),
|
1062
|
-
)
|
1063
|
-
|
1064
|
-
def to_lightgbm(self) -> Any:
|
1065
|
-
raise exceptions.SnowflakeMLException(
|
1066
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1067
|
-
original_exception=AttributeError(
|
1068
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1069
|
-
"to_lightgbm()",
|
1070
|
-
"to_sklearn()"
|
1071
|
-
)
|
1072
|
-
),
|
1073
|
-
)
|
1074
|
-
|
1075
|
-
def _get_dependencies(self) -> List[str]:
|
1076
|
-
return self._deps
|