snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
|
|
33
33
|
BatchInferenceKwargsTypedDict,
|
34
34
|
ScoreKwargsTypedDict
|
35
35
|
)
|
36
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
37
|
+
from snowflake.ml.model.model_signature import (
|
38
|
+
BaseFeatureSpec,
|
39
|
+
DataType,
|
40
|
+
FeatureSpec,
|
41
|
+
ModelSignature,
|
42
|
+
_infer_signature,
|
43
|
+
_rename_signature_with_snowflake_identifiers,
|
44
|
+
)
|
36
45
|
|
37
46
|
from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
|
38
47
|
|
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
|
|
43
52
|
validate_sklearn_args,
|
44
53
|
)
|
45
54
|
|
46
|
-
from snowflake.ml.model.model_signature import (
|
47
|
-
DataType,
|
48
|
-
FeatureSpec,
|
49
|
-
ModelSignature,
|
50
|
-
_infer_signature,
|
51
|
-
_rename_signature_with_snowflake_identifiers,
|
52
|
-
BaseFeatureSpec,
|
53
|
-
)
|
54
|
-
from snowflake.ml.model._signatures import utils as model_signature_utils
|
55
|
-
|
56
55
|
_PROJECT = "ModelDevelopment"
|
57
56
|
# Derive subproject from module name by removing "sklearn"
|
58
57
|
# and converting module name from underscore to CamelCase
|
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
61
60
|
|
62
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
63
62
|
|
64
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
65
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
66
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
67
|
-
return check
|
68
|
-
|
69
|
-
|
70
63
|
class ElasticNet(BaseTransformer):
|
71
64
|
r"""Linear regression with combined L1 and L2 priors as regularizer
|
72
65
|
For more details on this class, see [sklearn.linear_model.ElasticNet]
|
@@ -268,12 +261,7 @@ class ElasticNet(BaseTransformer):
|
|
268
261
|
)
|
269
262
|
return selected_cols
|
270
263
|
|
271
|
-
|
272
|
-
project=_PROJECT,
|
273
|
-
subproject=_SUBPROJECT,
|
274
|
-
custom_tags=dict([("autogen", True)]),
|
275
|
-
)
|
276
|
-
def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "ElasticNet":
|
264
|
+
def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "ElasticNet":
|
277
265
|
"""Fit model with coordinate descent
|
278
266
|
For more details on this function, see [sklearn.linear_model.ElasticNet.fit]
|
279
267
|
(https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.ElasticNet.html#sklearn.linear_model.ElasticNet.fit)
|
@@ -300,12 +288,14 @@ class ElasticNet(BaseTransformer):
|
|
300
288
|
|
301
289
|
self._snowpark_cols = dataset.select(self.input_cols).columns
|
302
290
|
|
303
|
-
|
291
|
+
# If we are already in a stored procedure, no need to kick off another one.
|
304
292
|
if SNOWML_SPROC_ENV in os.environ:
|
305
293
|
statement_params = telemetry.get_function_usage_statement_params(
|
306
294
|
project=_PROJECT,
|
307
295
|
subproject=_SUBPROJECT,
|
308
|
-
function_name=telemetry.get_statement_params_full_func_name(
|
296
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
297
|
+
inspect.currentframe(), ElasticNet.__class__.__name__
|
298
|
+
),
|
309
299
|
api_calls=[Session.call],
|
310
300
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
311
301
|
)
|
@@ -326,27 +316,24 @@ class ElasticNet(BaseTransformer):
|
|
326
316
|
)
|
327
317
|
self._sklearn_object = model_trainer.train()
|
328
318
|
self._is_fitted = True
|
329
|
-
self.
|
319
|
+
self._generate_model_signatures(dataset)
|
330
320
|
return self
|
331
321
|
|
332
322
|
def _batch_inference_validate_snowpark(
|
333
323
|
self,
|
334
324
|
dataset: DataFrame,
|
335
325
|
inference_method: str,
|
336
|
-
) ->
|
337
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
338
|
-
return the available package that exists in the snowflake anaconda channel
|
326
|
+
) -> None:
|
327
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
339
328
|
|
340
329
|
Args:
|
341
330
|
dataset: snowpark dataframe
|
342
331
|
inference_method: the inference method such as predict, score...
|
343
|
-
|
332
|
+
|
344
333
|
Raises:
|
345
334
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
346
335
|
SnowflakeMLException: If the session is None, raise error
|
347
336
|
|
348
|
-
Returns:
|
349
|
-
A list of available package that exists in the snowflake anaconda channel
|
350
337
|
"""
|
351
338
|
if not self._is_fitted:
|
352
339
|
raise exceptions.SnowflakeMLException(
|
@@ -364,9 +351,7 @@ class ElasticNet(BaseTransformer):
|
|
364
351
|
"Session must not specified for snowpark dataset."
|
365
352
|
),
|
366
353
|
)
|
367
|
-
|
368
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
369
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
354
|
+
|
370
355
|
|
371
356
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
372
357
|
@telemetry.send_api_usage_telemetry(
|
@@ -402,7 +387,9 @@ class ElasticNet(BaseTransformer):
|
|
402
387
|
# when it is classifier, infer the datatype from label columns
|
403
388
|
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
404
389
|
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
405
|
-
label_cols_signatures = [
|
390
|
+
label_cols_signatures = [
|
391
|
+
row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
|
392
|
+
]
|
406
393
|
if len(label_cols_signatures) == 0:
|
407
394
|
error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
|
408
395
|
raise exceptions.SnowflakeMLException(
|
@@ -410,25 +397,23 @@ class ElasticNet(BaseTransformer):
|
|
410
397
|
original_exception=ValueError(error_str),
|
411
398
|
)
|
412
399
|
|
413
|
-
expected_type_inferred = convert_sp_to_sf_type(
|
414
|
-
label_cols_signatures[0].as_snowpark_type()
|
415
|
-
)
|
400
|
+
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
416
401
|
|
417
|
-
self.
|
418
|
-
|
402
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
403
|
+
self._deps = self._get_dependencies()
|
404
|
+
assert isinstance(
|
405
|
+
dataset._session, Session
|
406
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
419
407
|
|
420
408
|
transform_kwargs = dict(
|
421
|
-
session
|
422
|
-
dependencies
|
423
|
-
drop_input_cols
|
424
|
-
expected_output_cols_type
|
409
|
+
session=dataset._session,
|
410
|
+
dependencies=self._deps,
|
411
|
+
drop_input_cols=self._drop_input_cols,
|
412
|
+
expected_output_cols_type=expected_type_inferred,
|
425
413
|
)
|
426
414
|
|
427
415
|
elif isinstance(dataset, pd.DataFrame):
|
428
|
-
transform_kwargs = dict(
|
429
|
-
snowpark_input_cols = self._snowpark_cols,
|
430
|
-
drop_input_cols = self._drop_input_cols
|
431
|
-
)
|
416
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
432
417
|
|
433
418
|
transform_handlers = ModelTransformerBuilder.build(
|
434
419
|
dataset=dataset,
|
@@ -468,7 +453,7 @@ class ElasticNet(BaseTransformer):
|
|
468
453
|
Transformed dataset.
|
469
454
|
"""
|
470
455
|
super()._check_dataset_type(dataset)
|
471
|
-
inference_method="transform"
|
456
|
+
inference_method = "transform"
|
472
457
|
|
473
458
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
474
459
|
# are specific to the type of dataset used.
|
@@ -498,24 +483,19 @@ class ElasticNet(BaseTransformer):
|
|
498
483
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
499
484
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
500
485
|
|
501
|
-
self.
|
502
|
-
|
503
|
-
inference_method=inference_method,
|
504
|
-
)
|
486
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
487
|
+
self._deps = self._get_dependencies()
|
505
488
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
506
489
|
|
507
490
|
transform_kwargs = dict(
|
508
|
-
session
|
509
|
-
dependencies
|
510
|
-
drop_input_cols
|
511
|
-
expected_output_cols_type
|
491
|
+
session=dataset._session,
|
492
|
+
dependencies=self._deps,
|
493
|
+
drop_input_cols=self._drop_input_cols,
|
494
|
+
expected_output_cols_type=expected_dtype,
|
512
495
|
)
|
513
496
|
|
514
497
|
elif isinstance(dataset, pd.DataFrame):
|
515
|
-
transform_kwargs = dict(
|
516
|
-
snowpark_input_cols = self._snowpark_cols,
|
517
|
-
drop_input_cols = self._drop_input_cols
|
518
|
-
)
|
498
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
519
499
|
|
520
500
|
transform_handlers = ModelTransformerBuilder.build(
|
521
501
|
dataset=dataset,
|
@@ -534,7 +514,11 @@ class ElasticNet(BaseTransformer):
|
|
534
514
|
return output_df
|
535
515
|
|
536
516
|
@available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
|
537
|
-
def fit_predict(
|
517
|
+
def fit_predict(
|
518
|
+
self,
|
519
|
+
dataset: Union[DataFrame, pd.DataFrame],
|
520
|
+
output_cols_prefix: str = "fit_predict_",
|
521
|
+
) -> Union[DataFrame, pd.DataFrame]:
|
538
522
|
""" Method not supported for this class.
|
539
523
|
|
540
524
|
|
@@ -559,22 +543,104 @@ class ElasticNet(BaseTransformer):
|
|
559
543
|
)
|
560
544
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
561
545
|
drop_input_cols=self._drop_input_cols,
|
562
|
-
expected_output_cols_list=
|
546
|
+
expected_output_cols_list=(
|
547
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
548
|
+
),
|
563
549
|
)
|
564
550
|
self._sklearn_object = fitted_estimator
|
565
551
|
self._is_fitted = True
|
566
552
|
return output_result
|
567
553
|
|
554
|
+
|
555
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
556
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
557
|
+
""" Method not supported for this class.
|
558
|
+
|
568
559
|
|
569
|
-
|
570
|
-
|
571
|
-
|
560
|
+
Raises:
|
561
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
562
|
+
|
563
|
+
Args:
|
564
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
565
|
+
Snowpark or Pandas DataFrame.
|
566
|
+
output_cols_prefix: Prefix for the response columns
|
572
567
|
Returns:
|
573
568
|
Transformed dataset.
|
574
569
|
"""
|
575
|
-
self.
|
576
|
-
|
577
|
-
|
570
|
+
self._infer_input_output_cols(dataset)
|
571
|
+
super()._check_dataset_type(dataset)
|
572
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
573
|
+
estimator=self._sklearn_object,
|
574
|
+
dataset=dataset,
|
575
|
+
input_cols=self.input_cols,
|
576
|
+
label_cols=self.label_cols,
|
577
|
+
sample_weight_col=self.sample_weight_col,
|
578
|
+
autogenerated=self._autogenerated,
|
579
|
+
subproject=_SUBPROJECT,
|
580
|
+
)
|
581
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
582
|
+
drop_input_cols=self._drop_input_cols,
|
583
|
+
expected_output_cols_list=self.output_cols,
|
584
|
+
)
|
585
|
+
self._sklearn_object = fitted_estimator
|
586
|
+
self._is_fitted = True
|
587
|
+
return output_result
|
588
|
+
|
589
|
+
|
590
|
+
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
591
|
+
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
592
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
593
|
+
"""
|
594
|
+
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
595
|
+
# The following condition is introduced for kneighbors methods, and not used in other methods
|
596
|
+
if output_cols:
|
597
|
+
output_cols = [
|
598
|
+
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
599
|
+
for c in output_cols
|
600
|
+
]
|
601
|
+
elif getattr(self._sklearn_object, "classes_", None) is None:
|
602
|
+
output_cols = [output_cols_prefix]
|
603
|
+
elif self._sklearn_object is not None:
|
604
|
+
classes = self._sklearn_object.classes_
|
605
|
+
if isinstance(classes, numpy.ndarray):
|
606
|
+
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
607
|
+
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
608
|
+
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
609
|
+
output_cols = []
|
610
|
+
for i, cl in enumerate(classes):
|
611
|
+
# For binary classification, there is only one output column for each class
|
612
|
+
# ndarray as the two classes are complementary.
|
613
|
+
if len(cl) == 2:
|
614
|
+
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
615
|
+
else:
|
616
|
+
output_cols.extend([
|
617
|
+
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
618
|
+
])
|
619
|
+
else:
|
620
|
+
output_cols = []
|
621
|
+
|
622
|
+
# Make sure column names are valid snowflake identifiers.
|
623
|
+
assert output_cols is not None # Make MyPy happy
|
624
|
+
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
625
|
+
|
626
|
+
return rv
|
627
|
+
|
628
|
+
def _align_expected_output_names(
|
629
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
630
|
+
) -> List[str]:
|
631
|
+
# in case the inferred output column names dimension is different
|
632
|
+
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
633
|
+
output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
|
634
|
+
output_df_columns = list(output_df_pd.columns)
|
635
|
+
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
636
|
+
if self.sample_weight_col:
|
637
|
+
output_df_columns_set -= set(self.sample_weight_col)
|
638
|
+
# if the dimension of inferred output column names is correct; use it
|
639
|
+
if len(expected_output_cols_list) == len(output_df_columns_set):
|
640
|
+
return expected_output_cols_list
|
641
|
+
# otherwise, use the sklearn estimator's output
|
642
|
+
else:
|
643
|
+
return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
578
644
|
|
579
645
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
580
646
|
@telemetry.send_api_usage_telemetry(
|
@@ -606,24 +672,26 @@ class ElasticNet(BaseTransformer):
|
|
606
672
|
# are specific to the type of dataset used.
|
607
673
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
608
674
|
|
675
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
676
|
+
|
609
677
|
if isinstance(dataset, DataFrame):
|
610
|
-
self.
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
678
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
679
|
+
self._deps = self._get_dependencies()
|
680
|
+
assert isinstance(
|
681
|
+
dataset._session, Session
|
682
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
615
683
|
transform_kwargs = dict(
|
616
684
|
session=dataset._session,
|
617
685
|
dependencies=self._deps,
|
618
|
-
drop_input_cols
|
686
|
+
drop_input_cols=self._drop_input_cols,
|
619
687
|
expected_output_cols_type="float",
|
620
688
|
)
|
689
|
+
expected_output_cols = self._align_expected_output_names(
|
690
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
691
|
+
)
|
621
692
|
|
622
693
|
elif isinstance(dataset, pd.DataFrame):
|
623
|
-
transform_kwargs = dict(
|
624
|
-
snowpark_input_cols = self._snowpark_cols,
|
625
|
-
drop_input_cols = self._drop_input_cols
|
626
|
-
)
|
694
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
627
695
|
|
628
696
|
transform_handlers = ModelTransformerBuilder.build(
|
629
697
|
dataset=dataset,
|
@@ -635,7 +703,7 @@ class ElasticNet(BaseTransformer):
|
|
635
703
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
636
704
|
inference_method=inference_method,
|
637
705
|
input_cols=self.input_cols,
|
638
|
-
expected_output_cols=
|
706
|
+
expected_output_cols=expected_output_cols,
|
639
707
|
**transform_kwargs
|
640
708
|
)
|
641
709
|
return output_df
|
@@ -665,29 +733,30 @@ class ElasticNet(BaseTransformer):
|
|
665
733
|
Output dataset with log probability of the sample for each class in the model.
|
666
734
|
"""
|
667
735
|
super()._check_dataset_type(dataset)
|
668
|
-
inference_method="predict_log_proba"
|
736
|
+
inference_method = "predict_log_proba"
|
737
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
669
738
|
|
670
739
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
671
740
|
# are specific to the type of dataset used.
|
672
741
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
673
742
|
|
674
743
|
if isinstance(dataset, DataFrame):
|
675
|
-
self.
|
676
|
-
|
677
|
-
|
678
|
-
|
679
|
-
|
744
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
745
|
+
self._deps = self._get_dependencies()
|
746
|
+
assert isinstance(
|
747
|
+
dataset._session, Session
|
748
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
680
749
|
transform_kwargs = dict(
|
681
750
|
session=dataset._session,
|
682
751
|
dependencies=self._deps,
|
683
|
-
drop_input_cols
|
752
|
+
drop_input_cols=self._drop_input_cols,
|
684
753
|
expected_output_cols_type="float",
|
685
754
|
)
|
755
|
+
expected_output_cols = self._align_expected_output_names(
|
756
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
757
|
+
)
|
686
758
|
elif isinstance(dataset, pd.DataFrame):
|
687
|
-
transform_kwargs = dict(
|
688
|
-
snowpark_input_cols = self._snowpark_cols,
|
689
|
-
drop_input_cols = self._drop_input_cols
|
690
|
-
)
|
759
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
691
760
|
|
692
761
|
transform_handlers = ModelTransformerBuilder.build(
|
693
762
|
dataset=dataset,
|
@@ -700,7 +769,7 @@ class ElasticNet(BaseTransformer):
|
|
700
769
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
701
770
|
inference_method=inference_method,
|
702
771
|
input_cols=self.input_cols,
|
703
|
-
expected_output_cols=
|
772
|
+
expected_output_cols=expected_output_cols,
|
704
773
|
**transform_kwargs
|
705
774
|
)
|
706
775
|
return output_df
|
@@ -726,30 +795,32 @@ class ElasticNet(BaseTransformer):
|
|
726
795
|
Output dataset with results of the decision function for the samples in input dataset.
|
727
796
|
"""
|
728
797
|
super()._check_dataset_type(dataset)
|
729
|
-
inference_method="decision_function"
|
798
|
+
inference_method = "decision_function"
|
730
799
|
|
731
800
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
732
801
|
# are specific to the type of dataset used.
|
733
802
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
734
803
|
|
804
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
805
|
+
|
735
806
|
if isinstance(dataset, DataFrame):
|
736
|
-
self.
|
737
|
-
|
738
|
-
|
739
|
-
|
740
|
-
|
807
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
808
|
+
self._deps = self._get_dependencies()
|
809
|
+
assert isinstance(
|
810
|
+
dataset._session, Session
|
811
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
741
812
|
transform_kwargs = dict(
|
742
813
|
session=dataset._session,
|
743
814
|
dependencies=self._deps,
|
744
|
-
drop_input_cols
|
815
|
+
drop_input_cols=self._drop_input_cols,
|
745
816
|
expected_output_cols_type="float",
|
746
817
|
)
|
818
|
+
expected_output_cols = self._align_expected_output_names(
|
819
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
820
|
+
)
|
747
821
|
|
748
822
|
elif isinstance(dataset, pd.DataFrame):
|
749
|
-
transform_kwargs = dict(
|
750
|
-
snowpark_input_cols = self._snowpark_cols,
|
751
|
-
drop_input_cols = self._drop_input_cols
|
752
|
-
)
|
823
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
753
824
|
|
754
825
|
transform_handlers = ModelTransformerBuilder.build(
|
755
826
|
dataset=dataset,
|
@@ -762,7 +833,7 @@ class ElasticNet(BaseTransformer):
|
|
762
833
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
763
834
|
inference_method=inference_method,
|
764
835
|
input_cols=self.input_cols,
|
765
|
-
expected_output_cols=
|
836
|
+
expected_output_cols=expected_output_cols,
|
766
837
|
**transform_kwargs
|
767
838
|
)
|
768
839
|
return output_df
|
@@ -791,17 +862,17 @@ class ElasticNet(BaseTransformer):
|
|
791
862
|
Output dataset with probability of the sample for each class in the model.
|
792
863
|
"""
|
793
864
|
super()._check_dataset_type(dataset)
|
794
|
-
inference_method="score_samples"
|
865
|
+
inference_method = "score_samples"
|
795
866
|
|
796
867
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
797
868
|
# are specific to the type of dataset used.
|
798
869
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
799
870
|
|
871
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
872
|
+
|
800
873
|
if isinstance(dataset, DataFrame):
|
801
|
-
self.
|
802
|
-
|
803
|
-
inference_method=inference_method,
|
804
|
-
)
|
874
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
875
|
+
self._deps = self._get_dependencies()
|
805
876
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
806
877
|
transform_kwargs = dict(
|
807
878
|
session=dataset._session,
|
@@ -809,6 +880,9 @@ class ElasticNet(BaseTransformer):
|
|
809
880
|
drop_input_cols = self._drop_input_cols,
|
810
881
|
expected_output_cols_type="float",
|
811
882
|
)
|
883
|
+
expected_output_cols = self._align_expected_output_names(
|
884
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
885
|
+
)
|
812
886
|
|
813
887
|
elif isinstance(dataset, pd.DataFrame):
|
814
888
|
transform_kwargs = dict(
|
@@ -827,7 +901,7 @@ class ElasticNet(BaseTransformer):
|
|
827
901
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
828
902
|
inference_method=inference_method,
|
829
903
|
input_cols=self.input_cols,
|
830
|
-
expected_output_cols=
|
904
|
+
expected_output_cols=expected_output_cols,
|
831
905
|
**transform_kwargs
|
832
906
|
)
|
833
907
|
return output_df
|
@@ -862,17 +936,15 @@ class ElasticNet(BaseTransformer):
|
|
862
936
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
863
937
|
|
864
938
|
if isinstance(dataset, DataFrame):
|
865
|
-
self.
|
866
|
-
|
867
|
-
inference_method="score",
|
868
|
-
)
|
939
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
940
|
+
self._deps = self._get_dependencies()
|
869
941
|
selected_cols = self._get_active_columns()
|
870
942
|
if len(selected_cols) > 0:
|
871
943
|
dataset = dataset.select(selected_cols)
|
872
944
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
873
945
|
transform_kwargs = dict(
|
874
946
|
session=dataset._session,
|
875
|
-
dependencies=
|
947
|
+
dependencies=self._deps,
|
876
948
|
score_sproc_imports=['sklearn'],
|
877
949
|
)
|
878
950
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -937,11 +1009,8 @@ class ElasticNet(BaseTransformer):
|
|
937
1009
|
|
938
1010
|
if isinstance(dataset, DataFrame):
|
939
1011
|
|
940
|
-
self.
|
941
|
-
|
942
|
-
inference_method=inference_method,
|
943
|
-
|
944
|
-
)
|
1012
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1013
|
+
self._deps = self._get_dependencies()
|
945
1014
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
946
1015
|
transform_kwargs = dict(
|
947
1016
|
session = dataset._session,
|
@@ -974,50 +1043,84 @@ class ElasticNet(BaseTransformer):
|
|
974
1043
|
)
|
975
1044
|
return output_df
|
976
1045
|
|
1046
|
+
|
1047
|
+
|
1048
|
+
def to_sklearn(self) -> Any:
|
1049
|
+
"""Get sklearn.linear_model.ElasticNet object.
|
1050
|
+
"""
|
1051
|
+
if self._sklearn_object is None:
|
1052
|
+
self._sklearn_object = self._create_sklearn_object()
|
1053
|
+
return self._sklearn_object
|
1054
|
+
|
1055
|
+
def to_xgboost(self) -> Any:
|
1056
|
+
raise exceptions.SnowflakeMLException(
|
1057
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1058
|
+
original_exception=AttributeError(
|
1059
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1060
|
+
"to_xgboost()",
|
1061
|
+
"to_sklearn()"
|
1062
|
+
)
|
1063
|
+
),
|
1064
|
+
)
|
1065
|
+
|
1066
|
+
def to_lightgbm(self) -> Any:
|
1067
|
+
raise exceptions.SnowflakeMLException(
|
1068
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1069
|
+
original_exception=AttributeError(
|
1070
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1071
|
+
"to_lightgbm()",
|
1072
|
+
"to_sklearn()"
|
1073
|
+
)
|
1074
|
+
),
|
1075
|
+
)
|
1076
|
+
|
1077
|
+
def _get_dependencies(self) -> List[str]:
|
1078
|
+
return self._deps
|
1079
|
+
|
977
1080
|
|
978
|
-
def
|
1081
|
+
def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
979
1082
|
self._model_signature_dict = dict()
|
980
1083
|
|
981
1084
|
PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
|
982
1085
|
|
983
|
-
inputs = list(_infer_signature(dataset[self.input_cols], "input"))
|
1086
|
+
inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
|
984
1087
|
outputs: List[BaseFeatureSpec] = []
|
985
1088
|
if hasattr(self, "predict"):
|
986
1089
|
# keep mypy happy
|
987
|
-
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1090
|
+
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
988
1091
|
# For classifier, the type of predict is the same as the type of label
|
989
|
-
if self._sklearn_object._estimator_type ==
|
990
|
-
|
1092
|
+
if self._sklearn_object._estimator_type == "classifier":
|
1093
|
+
# label columns is the desired type for output
|
991
1094
|
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
992
1095
|
# rename the output columns
|
993
1096
|
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
994
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
995
|
-
|
996
|
-
|
1097
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1098
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1099
|
+
)
|
997
1100
|
# For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
|
998
1101
|
# For outlier models, returns -1 for outliers and 1 for inliers.
|
999
|
-
# Clusterer returns int64 cluster labels.
|
1102
|
+
# Clusterer returns int64 cluster labels.
|
1000
1103
|
elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
|
1001
1104
|
outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
|
1002
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1003
|
-
|
1004
|
-
|
1005
|
-
|
1105
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1106
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1107
|
+
)
|
1108
|
+
|
1006
1109
|
# For regressor, the type of predict is float64
|
1007
|
-
elif self._sklearn_object._estimator_type ==
|
1110
|
+
elif self._sklearn_object._estimator_type == "regressor":
|
1008
1111
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
1009
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1010
|
-
|
1011
|
-
|
1012
|
-
|
1112
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1113
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1114
|
+
)
|
1115
|
+
|
1013
1116
|
for prob_func in PROB_FUNCTIONS:
|
1014
1117
|
if hasattr(self, prob_func):
|
1015
1118
|
output_cols_prefix: str = f"{prob_func}_"
|
1016
1119
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
1017
1120
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
1018
|
-
self._model_signature_dict[prob_func] = ModelSignature(
|
1019
|
-
|
1020
|
-
|
1121
|
+
self._model_signature_dict[prob_func] = ModelSignature(
|
1122
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1123
|
+
)
|
1021
1124
|
|
1022
1125
|
# Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
|
1023
1126
|
items = list(self._model_signature_dict.items())
|
@@ -1030,10 +1133,10 @@ class ElasticNet(BaseTransformer):
|
|
1030
1133
|
"""Returns model signature of current class.
|
1031
1134
|
|
1032
1135
|
Raises:
|
1033
|
-
|
1136
|
+
SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
|
1034
1137
|
|
1035
1138
|
Returns:
|
1036
|
-
Dict
|
1139
|
+
Dict with each method and its input output signature
|
1037
1140
|
"""
|
1038
1141
|
if self._model_signature_dict is None:
|
1039
1142
|
raise exceptions.SnowflakeMLException(
|
@@ -1041,35 +1144,3 @@ class ElasticNet(BaseTransformer):
|
|
1041
1144
|
original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
|
1042
1145
|
)
|
1043
1146
|
return self._model_signature_dict
|
1044
|
-
|
1045
|
-
def to_sklearn(self) -> Any:
|
1046
|
-
"""Get sklearn.linear_model.ElasticNet object.
|
1047
|
-
"""
|
1048
|
-
if self._sklearn_object is None:
|
1049
|
-
self._sklearn_object = self._create_sklearn_object()
|
1050
|
-
return self._sklearn_object
|
1051
|
-
|
1052
|
-
def to_xgboost(self) -> Any:
|
1053
|
-
raise exceptions.SnowflakeMLException(
|
1054
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1055
|
-
original_exception=AttributeError(
|
1056
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1057
|
-
"to_xgboost()",
|
1058
|
-
"to_sklearn()"
|
1059
|
-
)
|
1060
|
-
),
|
1061
|
-
)
|
1062
|
-
|
1063
|
-
def to_lightgbm(self) -> Any:
|
1064
|
-
raise exceptions.SnowflakeMLException(
|
1065
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1066
|
-
original_exception=AttributeError(
|
1067
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1068
|
-
"to_lightgbm()",
|
1069
|
-
"to_sklearn()"
|
1070
|
-
)
|
1071
|
-
),
|
1072
|
-
)
|
1073
|
-
|
1074
|
-
def _get_dependencies(self) -> List[str]:
|
1075
|
-
return self._deps
|