snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
|
|
33
33
|
BatchInferenceKwargsTypedDict,
|
34
34
|
ScoreKwargsTypedDict
|
35
35
|
)
|
36
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
37
|
+
from snowflake.ml.model.model_signature import (
|
38
|
+
BaseFeatureSpec,
|
39
|
+
DataType,
|
40
|
+
FeatureSpec,
|
41
|
+
ModelSignature,
|
42
|
+
_infer_signature,
|
43
|
+
_rename_signature_with_snowflake_identifiers,
|
44
|
+
)
|
36
45
|
|
37
46
|
from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
|
38
47
|
|
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
|
|
43
52
|
validate_sklearn_args,
|
44
53
|
)
|
45
54
|
|
46
|
-
from snowflake.ml.model.model_signature import (
|
47
|
-
DataType,
|
48
|
-
FeatureSpec,
|
49
|
-
ModelSignature,
|
50
|
-
_infer_signature,
|
51
|
-
_rename_signature_with_snowflake_identifiers,
|
52
|
-
BaseFeatureSpec,
|
53
|
-
)
|
54
|
-
from snowflake.ml.model._signatures import utils as model_signature_utils
|
55
|
-
|
56
55
|
_PROJECT = "ModelDevelopment"
|
57
56
|
# Derive subproject from module name by removing "sklearn"
|
58
57
|
# and converting module name from underscore to CamelCase
|
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
61
60
|
|
62
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
63
62
|
|
64
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
65
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
66
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
67
|
-
return check
|
68
|
-
|
69
|
-
|
70
63
|
class LassoCV(BaseTransformer):
|
71
64
|
r"""Lasso linear model with iterative fitting along a regularization path
|
72
65
|
For more details on this class, see [sklearn.linear_model.LassoCV]
|
@@ -290,12 +283,7 @@ class LassoCV(BaseTransformer):
|
|
290
283
|
)
|
291
284
|
return selected_cols
|
292
285
|
|
293
|
-
|
294
|
-
project=_PROJECT,
|
295
|
-
subproject=_SUBPROJECT,
|
296
|
-
custom_tags=dict([("autogen", True)]),
|
297
|
-
)
|
298
|
-
def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "LassoCV":
|
286
|
+
def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "LassoCV":
|
299
287
|
"""Fit linear model with coordinate descent
|
300
288
|
For more details on this function, see [sklearn.linear_model.LassoCV.fit]
|
301
289
|
(https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LassoCV.html#sklearn.linear_model.LassoCV.fit)
|
@@ -322,12 +310,14 @@ class LassoCV(BaseTransformer):
|
|
322
310
|
|
323
311
|
self._snowpark_cols = dataset.select(self.input_cols).columns
|
324
312
|
|
325
|
-
|
313
|
+
# If we are already in a stored procedure, no need to kick off another one.
|
326
314
|
if SNOWML_SPROC_ENV in os.environ:
|
327
315
|
statement_params = telemetry.get_function_usage_statement_params(
|
328
316
|
project=_PROJECT,
|
329
317
|
subproject=_SUBPROJECT,
|
330
|
-
function_name=telemetry.get_statement_params_full_func_name(
|
318
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
319
|
+
inspect.currentframe(), LassoCV.__class__.__name__
|
320
|
+
),
|
331
321
|
api_calls=[Session.call],
|
332
322
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
333
323
|
)
|
@@ -348,27 +338,24 @@ class LassoCV(BaseTransformer):
|
|
348
338
|
)
|
349
339
|
self._sklearn_object = model_trainer.train()
|
350
340
|
self._is_fitted = True
|
351
|
-
self.
|
341
|
+
self._generate_model_signatures(dataset)
|
352
342
|
return self
|
353
343
|
|
354
344
|
def _batch_inference_validate_snowpark(
|
355
345
|
self,
|
356
346
|
dataset: DataFrame,
|
357
347
|
inference_method: str,
|
358
|
-
) ->
|
359
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
360
|
-
return the available package that exists in the snowflake anaconda channel
|
348
|
+
) -> None:
|
349
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
361
350
|
|
362
351
|
Args:
|
363
352
|
dataset: snowpark dataframe
|
364
353
|
inference_method: the inference method such as predict, score...
|
365
|
-
|
354
|
+
|
366
355
|
Raises:
|
367
356
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
368
357
|
SnowflakeMLException: If the session is None, raise error
|
369
358
|
|
370
|
-
Returns:
|
371
|
-
A list of available package that exists in the snowflake anaconda channel
|
372
359
|
"""
|
373
360
|
if not self._is_fitted:
|
374
361
|
raise exceptions.SnowflakeMLException(
|
@@ -386,9 +373,7 @@ class LassoCV(BaseTransformer):
|
|
386
373
|
"Session must not specified for snowpark dataset."
|
387
374
|
),
|
388
375
|
)
|
389
|
-
|
390
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
391
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
376
|
+
|
392
377
|
|
393
378
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
394
379
|
@telemetry.send_api_usage_telemetry(
|
@@ -424,7 +409,9 @@ class LassoCV(BaseTransformer):
|
|
424
409
|
# when it is classifier, infer the datatype from label columns
|
425
410
|
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
426
411
|
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
427
|
-
label_cols_signatures = [
|
412
|
+
label_cols_signatures = [
|
413
|
+
row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
|
414
|
+
]
|
428
415
|
if len(label_cols_signatures) == 0:
|
429
416
|
error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
|
430
417
|
raise exceptions.SnowflakeMLException(
|
@@ -432,25 +419,23 @@ class LassoCV(BaseTransformer):
|
|
432
419
|
original_exception=ValueError(error_str),
|
433
420
|
)
|
434
421
|
|
435
|
-
expected_type_inferred = convert_sp_to_sf_type(
|
436
|
-
label_cols_signatures[0].as_snowpark_type()
|
437
|
-
)
|
422
|
+
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
438
423
|
|
439
|
-
self.
|
440
|
-
|
424
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
425
|
+
self._deps = self._get_dependencies()
|
426
|
+
assert isinstance(
|
427
|
+
dataset._session, Session
|
428
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
441
429
|
|
442
430
|
transform_kwargs = dict(
|
443
|
-
session
|
444
|
-
dependencies
|
445
|
-
drop_input_cols
|
446
|
-
expected_output_cols_type
|
431
|
+
session=dataset._session,
|
432
|
+
dependencies=self._deps,
|
433
|
+
drop_input_cols=self._drop_input_cols,
|
434
|
+
expected_output_cols_type=expected_type_inferred,
|
447
435
|
)
|
448
436
|
|
449
437
|
elif isinstance(dataset, pd.DataFrame):
|
450
|
-
transform_kwargs = dict(
|
451
|
-
snowpark_input_cols = self._snowpark_cols,
|
452
|
-
drop_input_cols = self._drop_input_cols
|
453
|
-
)
|
438
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
454
439
|
|
455
440
|
transform_handlers = ModelTransformerBuilder.build(
|
456
441
|
dataset=dataset,
|
@@ -490,7 +475,7 @@ class LassoCV(BaseTransformer):
|
|
490
475
|
Transformed dataset.
|
491
476
|
"""
|
492
477
|
super()._check_dataset_type(dataset)
|
493
|
-
inference_method="transform"
|
478
|
+
inference_method = "transform"
|
494
479
|
|
495
480
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
496
481
|
# are specific to the type of dataset used.
|
@@ -520,24 +505,19 @@ class LassoCV(BaseTransformer):
|
|
520
505
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
521
506
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
522
507
|
|
523
|
-
self.
|
524
|
-
|
525
|
-
inference_method=inference_method,
|
526
|
-
)
|
508
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
509
|
+
self._deps = self._get_dependencies()
|
527
510
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
528
511
|
|
529
512
|
transform_kwargs = dict(
|
530
|
-
session
|
531
|
-
dependencies
|
532
|
-
drop_input_cols
|
533
|
-
expected_output_cols_type
|
513
|
+
session=dataset._session,
|
514
|
+
dependencies=self._deps,
|
515
|
+
drop_input_cols=self._drop_input_cols,
|
516
|
+
expected_output_cols_type=expected_dtype,
|
534
517
|
)
|
535
518
|
|
536
519
|
elif isinstance(dataset, pd.DataFrame):
|
537
|
-
transform_kwargs = dict(
|
538
|
-
snowpark_input_cols = self._snowpark_cols,
|
539
|
-
drop_input_cols = self._drop_input_cols
|
540
|
-
)
|
520
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
541
521
|
|
542
522
|
transform_handlers = ModelTransformerBuilder.build(
|
543
523
|
dataset=dataset,
|
@@ -556,7 +536,11 @@ class LassoCV(BaseTransformer):
|
|
556
536
|
return output_df
|
557
537
|
|
558
538
|
@available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
|
559
|
-
def fit_predict(
|
539
|
+
def fit_predict(
|
540
|
+
self,
|
541
|
+
dataset: Union[DataFrame, pd.DataFrame],
|
542
|
+
output_cols_prefix: str = "fit_predict_",
|
543
|
+
) -> Union[DataFrame, pd.DataFrame]:
|
560
544
|
""" Method not supported for this class.
|
561
545
|
|
562
546
|
|
@@ -581,22 +565,104 @@ class LassoCV(BaseTransformer):
|
|
581
565
|
)
|
582
566
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
583
567
|
drop_input_cols=self._drop_input_cols,
|
584
|
-
expected_output_cols_list=
|
568
|
+
expected_output_cols_list=(
|
569
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
570
|
+
),
|
585
571
|
)
|
586
572
|
self._sklearn_object = fitted_estimator
|
587
573
|
self._is_fitted = True
|
588
574
|
return output_result
|
589
575
|
|
576
|
+
|
577
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
578
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
579
|
+
""" Method not supported for this class.
|
580
|
+
|
590
581
|
|
591
|
-
|
592
|
-
|
593
|
-
|
582
|
+
Raises:
|
583
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
584
|
+
|
585
|
+
Args:
|
586
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
587
|
+
Snowpark or Pandas DataFrame.
|
588
|
+
output_cols_prefix: Prefix for the response columns
|
594
589
|
Returns:
|
595
590
|
Transformed dataset.
|
596
591
|
"""
|
597
|
-
self.
|
598
|
-
|
599
|
-
|
592
|
+
self._infer_input_output_cols(dataset)
|
593
|
+
super()._check_dataset_type(dataset)
|
594
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
595
|
+
estimator=self._sklearn_object,
|
596
|
+
dataset=dataset,
|
597
|
+
input_cols=self.input_cols,
|
598
|
+
label_cols=self.label_cols,
|
599
|
+
sample_weight_col=self.sample_weight_col,
|
600
|
+
autogenerated=self._autogenerated,
|
601
|
+
subproject=_SUBPROJECT,
|
602
|
+
)
|
603
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
604
|
+
drop_input_cols=self._drop_input_cols,
|
605
|
+
expected_output_cols_list=self.output_cols,
|
606
|
+
)
|
607
|
+
self._sklearn_object = fitted_estimator
|
608
|
+
self._is_fitted = True
|
609
|
+
return output_result
|
610
|
+
|
611
|
+
|
612
|
+
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
613
|
+
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
614
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
615
|
+
"""
|
616
|
+
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
617
|
+
# The following condition is introduced for kneighbors methods, and not used in other methods
|
618
|
+
if output_cols:
|
619
|
+
output_cols = [
|
620
|
+
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
621
|
+
for c in output_cols
|
622
|
+
]
|
623
|
+
elif getattr(self._sklearn_object, "classes_", None) is None:
|
624
|
+
output_cols = [output_cols_prefix]
|
625
|
+
elif self._sklearn_object is not None:
|
626
|
+
classes = self._sklearn_object.classes_
|
627
|
+
if isinstance(classes, numpy.ndarray):
|
628
|
+
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
629
|
+
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
630
|
+
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
631
|
+
output_cols = []
|
632
|
+
for i, cl in enumerate(classes):
|
633
|
+
# For binary classification, there is only one output column for each class
|
634
|
+
# ndarray as the two classes are complementary.
|
635
|
+
if len(cl) == 2:
|
636
|
+
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
637
|
+
else:
|
638
|
+
output_cols.extend([
|
639
|
+
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
640
|
+
])
|
641
|
+
else:
|
642
|
+
output_cols = []
|
643
|
+
|
644
|
+
# Make sure column names are valid snowflake identifiers.
|
645
|
+
assert output_cols is not None # Make MyPy happy
|
646
|
+
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
647
|
+
|
648
|
+
return rv
|
649
|
+
|
650
|
+
def _align_expected_output_names(
|
651
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
652
|
+
) -> List[str]:
|
653
|
+
# in case the inferred output column names dimension is different
|
654
|
+
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
655
|
+
output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
|
656
|
+
output_df_columns = list(output_df_pd.columns)
|
657
|
+
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
658
|
+
if self.sample_weight_col:
|
659
|
+
output_df_columns_set -= set(self.sample_weight_col)
|
660
|
+
# if the dimension of inferred output column names is correct; use it
|
661
|
+
if len(expected_output_cols_list) == len(output_df_columns_set):
|
662
|
+
return expected_output_cols_list
|
663
|
+
# otherwise, use the sklearn estimator's output
|
664
|
+
else:
|
665
|
+
return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
600
666
|
|
601
667
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
602
668
|
@telemetry.send_api_usage_telemetry(
|
@@ -628,24 +694,26 @@ class LassoCV(BaseTransformer):
|
|
628
694
|
# are specific to the type of dataset used.
|
629
695
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
630
696
|
|
697
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
698
|
+
|
631
699
|
if isinstance(dataset, DataFrame):
|
632
|
-
self.
|
633
|
-
|
634
|
-
|
635
|
-
|
636
|
-
|
700
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
701
|
+
self._deps = self._get_dependencies()
|
702
|
+
assert isinstance(
|
703
|
+
dataset._session, Session
|
704
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
637
705
|
transform_kwargs = dict(
|
638
706
|
session=dataset._session,
|
639
707
|
dependencies=self._deps,
|
640
|
-
drop_input_cols
|
708
|
+
drop_input_cols=self._drop_input_cols,
|
641
709
|
expected_output_cols_type="float",
|
642
710
|
)
|
711
|
+
expected_output_cols = self._align_expected_output_names(
|
712
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
713
|
+
)
|
643
714
|
|
644
715
|
elif isinstance(dataset, pd.DataFrame):
|
645
|
-
transform_kwargs = dict(
|
646
|
-
snowpark_input_cols = self._snowpark_cols,
|
647
|
-
drop_input_cols = self._drop_input_cols
|
648
|
-
)
|
716
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
649
717
|
|
650
718
|
transform_handlers = ModelTransformerBuilder.build(
|
651
719
|
dataset=dataset,
|
@@ -657,7 +725,7 @@ class LassoCV(BaseTransformer):
|
|
657
725
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
658
726
|
inference_method=inference_method,
|
659
727
|
input_cols=self.input_cols,
|
660
|
-
expected_output_cols=
|
728
|
+
expected_output_cols=expected_output_cols,
|
661
729
|
**transform_kwargs
|
662
730
|
)
|
663
731
|
return output_df
|
@@ -687,29 +755,30 @@ class LassoCV(BaseTransformer):
|
|
687
755
|
Output dataset with log probability of the sample for each class in the model.
|
688
756
|
"""
|
689
757
|
super()._check_dataset_type(dataset)
|
690
|
-
inference_method="predict_log_proba"
|
758
|
+
inference_method = "predict_log_proba"
|
759
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
691
760
|
|
692
761
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
693
762
|
# are specific to the type of dataset used.
|
694
763
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
695
764
|
|
696
765
|
if isinstance(dataset, DataFrame):
|
697
|
-
self.
|
698
|
-
|
699
|
-
|
700
|
-
|
701
|
-
|
766
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
767
|
+
self._deps = self._get_dependencies()
|
768
|
+
assert isinstance(
|
769
|
+
dataset._session, Session
|
770
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
702
771
|
transform_kwargs = dict(
|
703
772
|
session=dataset._session,
|
704
773
|
dependencies=self._deps,
|
705
|
-
drop_input_cols
|
774
|
+
drop_input_cols=self._drop_input_cols,
|
706
775
|
expected_output_cols_type="float",
|
707
776
|
)
|
777
|
+
expected_output_cols = self._align_expected_output_names(
|
778
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
779
|
+
)
|
708
780
|
elif isinstance(dataset, pd.DataFrame):
|
709
|
-
transform_kwargs = dict(
|
710
|
-
snowpark_input_cols = self._snowpark_cols,
|
711
|
-
drop_input_cols = self._drop_input_cols
|
712
|
-
)
|
781
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
713
782
|
|
714
783
|
transform_handlers = ModelTransformerBuilder.build(
|
715
784
|
dataset=dataset,
|
@@ -722,7 +791,7 @@ class LassoCV(BaseTransformer):
|
|
722
791
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
723
792
|
inference_method=inference_method,
|
724
793
|
input_cols=self.input_cols,
|
725
|
-
expected_output_cols=
|
794
|
+
expected_output_cols=expected_output_cols,
|
726
795
|
**transform_kwargs
|
727
796
|
)
|
728
797
|
return output_df
|
@@ -748,30 +817,32 @@ class LassoCV(BaseTransformer):
|
|
748
817
|
Output dataset with results of the decision function for the samples in input dataset.
|
749
818
|
"""
|
750
819
|
super()._check_dataset_type(dataset)
|
751
|
-
inference_method="decision_function"
|
820
|
+
inference_method = "decision_function"
|
752
821
|
|
753
822
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
754
823
|
# are specific to the type of dataset used.
|
755
824
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
756
825
|
|
826
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
827
|
+
|
757
828
|
if isinstance(dataset, DataFrame):
|
758
|
-
self.
|
759
|
-
|
760
|
-
|
761
|
-
|
762
|
-
|
829
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
830
|
+
self._deps = self._get_dependencies()
|
831
|
+
assert isinstance(
|
832
|
+
dataset._session, Session
|
833
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
763
834
|
transform_kwargs = dict(
|
764
835
|
session=dataset._session,
|
765
836
|
dependencies=self._deps,
|
766
|
-
drop_input_cols
|
837
|
+
drop_input_cols=self._drop_input_cols,
|
767
838
|
expected_output_cols_type="float",
|
768
839
|
)
|
840
|
+
expected_output_cols = self._align_expected_output_names(
|
841
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
842
|
+
)
|
769
843
|
|
770
844
|
elif isinstance(dataset, pd.DataFrame):
|
771
|
-
transform_kwargs = dict(
|
772
|
-
snowpark_input_cols = self._snowpark_cols,
|
773
|
-
drop_input_cols = self._drop_input_cols
|
774
|
-
)
|
845
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
775
846
|
|
776
847
|
transform_handlers = ModelTransformerBuilder.build(
|
777
848
|
dataset=dataset,
|
@@ -784,7 +855,7 @@ class LassoCV(BaseTransformer):
|
|
784
855
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
785
856
|
inference_method=inference_method,
|
786
857
|
input_cols=self.input_cols,
|
787
|
-
expected_output_cols=
|
858
|
+
expected_output_cols=expected_output_cols,
|
788
859
|
**transform_kwargs
|
789
860
|
)
|
790
861
|
return output_df
|
@@ -813,17 +884,17 @@ class LassoCV(BaseTransformer):
|
|
813
884
|
Output dataset with probability of the sample for each class in the model.
|
814
885
|
"""
|
815
886
|
super()._check_dataset_type(dataset)
|
816
|
-
inference_method="score_samples"
|
887
|
+
inference_method = "score_samples"
|
817
888
|
|
818
889
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
819
890
|
# are specific to the type of dataset used.
|
820
891
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
821
892
|
|
893
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
894
|
+
|
822
895
|
if isinstance(dataset, DataFrame):
|
823
|
-
self.
|
824
|
-
|
825
|
-
inference_method=inference_method,
|
826
|
-
)
|
896
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
897
|
+
self._deps = self._get_dependencies()
|
827
898
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
828
899
|
transform_kwargs = dict(
|
829
900
|
session=dataset._session,
|
@@ -831,6 +902,9 @@ class LassoCV(BaseTransformer):
|
|
831
902
|
drop_input_cols = self._drop_input_cols,
|
832
903
|
expected_output_cols_type="float",
|
833
904
|
)
|
905
|
+
expected_output_cols = self._align_expected_output_names(
|
906
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
907
|
+
)
|
834
908
|
|
835
909
|
elif isinstance(dataset, pd.DataFrame):
|
836
910
|
transform_kwargs = dict(
|
@@ -849,7 +923,7 @@ class LassoCV(BaseTransformer):
|
|
849
923
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
850
924
|
inference_method=inference_method,
|
851
925
|
input_cols=self.input_cols,
|
852
|
-
expected_output_cols=
|
926
|
+
expected_output_cols=expected_output_cols,
|
853
927
|
**transform_kwargs
|
854
928
|
)
|
855
929
|
return output_df
|
@@ -884,17 +958,15 @@ class LassoCV(BaseTransformer):
|
|
884
958
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
885
959
|
|
886
960
|
if isinstance(dataset, DataFrame):
|
887
|
-
self.
|
888
|
-
|
889
|
-
inference_method="score",
|
890
|
-
)
|
961
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
962
|
+
self._deps = self._get_dependencies()
|
891
963
|
selected_cols = self._get_active_columns()
|
892
964
|
if len(selected_cols) > 0:
|
893
965
|
dataset = dataset.select(selected_cols)
|
894
966
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
895
967
|
transform_kwargs = dict(
|
896
968
|
session=dataset._session,
|
897
|
-
dependencies=
|
969
|
+
dependencies=self._deps,
|
898
970
|
score_sproc_imports=['sklearn'],
|
899
971
|
)
|
900
972
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -959,11 +1031,8 @@ class LassoCV(BaseTransformer):
|
|
959
1031
|
|
960
1032
|
if isinstance(dataset, DataFrame):
|
961
1033
|
|
962
|
-
self.
|
963
|
-
|
964
|
-
inference_method=inference_method,
|
965
|
-
|
966
|
-
)
|
1034
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1035
|
+
self._deps = self._get_dependencies()
|
967
1036
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
968
1037
|
transform_kwargs = dict(
|
969
1038
|
session = dataset._session,
|
@@ -996,50 +1065,84 @@ class LassoCV(BaseTransformer):
|
|
996
1065
|
)
|
997
1066
|
return output_df
|
998
1067
|
|
1068
|
+
|
1069
|
+
|
1070
|
+
def to_sklearn(self) -> Any:
|
1071
|
+
"""Get sklearn.linear_model.LassoCV object.
|
1072
|
+
"""
|
1073
|
+
if self._sklearn_object is None:
|
1074
|
+
self._sklearn_object = self._create_sklearn_object()
|
1075
|
+
return self._sklearn_object
|
1076
|
+
|
1077
|
+
def to_xgboost(self) -> Any:
|
1078
|
+
raise exceptions.SnowflakeMLException(
|
1079
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1080
|
+
original_exception=AttributeError(
|
1081
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1082
|
+
"to_xgboost()",
|
1083
|
+
"to_sklearn()"
|
1084
|
+
)
|
1085
|
+
),
|
1086
|
+
)
|
1087
|
+
|
1088
|
+
def to_lightgbm(self) -> Any:
|
1089
|
+
raise exceptions.SnowflakeMLException(
|
1090
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1091
|
+
original_exception=AttributeError(
|
1092
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1093
|
+
"to_lightgbm()",
|
1094
|
+
"to_sklearn()"
|
1095
|
+
)
|
1096
|
+
),
|
1097
|
+
)
|
1098
|
+
|
1099
|
+
def _get_dependencies(self) -> List[str]:
|
1100
|
+
return self._deps
|
1101
|
+
|
999
1102
|
|
1000
|
-
def
|
1103
|
+
def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
1001
1104
|
self._model_signature_dict = dict()
|
1002
1105
|
|
1003
1106
|
PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
|
1004
1107
|
|
1005
|
-
inputs = list(_infer_signature(dataset[self.input_cols], "input"))
|
1108
|
+
inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
|
1006
1109
|
outputs: List[BaseFeatureSpec] = []
|
1007
1110
|
if hasattr(self, "predict"):
|
1008
1111
|
# keep mypy happy
|
1009
|
-
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1112
|
+
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1010
1113
|
# For classifier, the type of predict is the same as the type of label
|
1011
|
-
if self._sklearn_object._estimator_type ==
|
1012
|
-
|
1114
|
+
if self._sklearn_object._estimator_type == "classifier":
|
1115
|
+
# label columns is the desired type for output
|
1013
1116
|
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
1014
1117
|
# rename the output columns
|
1015
1118
|
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
1016
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1017
|
-
|
1018
|
-
|
1119
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1120
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1121
|
+
)
|
1019
1122
|
# For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
|
1020
1123
|
# For outlier models, returns -1 for outliers and 1 for inliers.
|
1021
|
-
# Clusterer returns int64 cluster labels.
|
1124
|
+
# Clusterer returns int64 cluster labels.
|
1022
1125
|
elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
|
1023
1126
|
outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
|
1024
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1025
|
-
|
1026
|
-
|
1027
|
-
|
1127
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1128
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1129
|
+
)
|
1130
|
+
|
1028
1131
|
# For regressor, the type of predict is float64
|
1029
|
-
elif self._sklearn_object._estimator_type ==
|
1132
|
+
elif self._sklearn_object._estimator_type == "regressor":
|
1030
1133
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
1031
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1032
|
-
|
1033
|
-
|
1034
|
-
|
1134
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1135
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1136
|
+
)
|
1137
|
+
|
1035
1138
|
for prob_func in PROB_FUNCTIONS:
|
1036
1139
|
if hasattr(self, prob_func):
|
1037
1140
|
output_cols_prefix: str = f"{prob_func}_"
|
1038
1141
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
1039
1142
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
1040
|
-
self._model_signature_dict[prob_func] = ModelSignature(
|
1041
|
-
|
1042
|
-
|
1143
|
+
self._model_signature_dict[prob_func] = ModelSignature(
|
1144
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1145
|
+
)
|
1043
1146
|
|
1044
1147
|
# Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
|
1045
1148
|
items = list(self._model_signature_dict.items())
|
@@ -1052,10 +1155,10 @@ class LassoCV(BaseTransformer):
|
|
1052
1155
|
"""Returns model signature of current class.
|
1053
1156
|
|
1054
1157
|
Raises:
|
1055
|
-
|
1158
|
+
SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
|
1056
1159
|
|
1057
1160
|
Returns:
|
1058
|
-
Dict
|
1161
|
+
Dict with each method and its input output signature
|
1059
1162
|
"""
|
1060
1163
|
if self._model_signature_dict is None:
|
1061
1164
|
raise exceptions.SnowflakeMLException(
|
@@ -1063,35 +1166,3 @@ class LassoCV(BaseTransformer):
|
|
1063
1166
|
original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
|
1064
1167
|
)
|
1065
1168
|
return self._model_signature_dict
|
1066
|
-
|
1067
|
-
def to_sklearn(self) -> Any:
|
1068
|
-
"""Get sklearn.linear_model.LassoCV object.
|
1069
|
-
"""
|
1070
|
-
if self._sklearn_object is None:
|
1071
|
-
self._sklearn_object = self._create_sklearn_object()
|
1072
|
-
return self._sklearn_object
|
1073
|
-
|
1074
|
-
def to_xgboost(self) -> Any:
|
1075
|
-
raise exceptions.SnowflakeMLException(
|
1076
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1077
|
-
original_exception=AttributeError(
|
1078
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1079
|
-
"to_xgboost()",
|
1080
|
-
"to_sklearn()"
|
1081
|
-
)
|
1082
|
-
),
|
1083
|
-
)
|
1084
|
-
|
1085
|
-
def to_lightgbm(self) -> Any:
|
1086
|
-
raise exceptions.SnowflakeMLException(
|
1087
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1088
|
-
original_exception=AttributeError(
|
1089
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1090
|
-
"to_lightgbm()",
|
1091
|
-
"to_sklearn()"
|
1092
|
-
)
|
1093
|
-
),
|
1094
|
-
)
|
1095
|
-
|
1096
|
-
def _get_dependencies(self) -> List[str]:
|
1097
|
-
return self._deps
|