snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
|
|
33
33
|
BatchInferenceKwargsTypedDict,
|
34
34
|
ScoreKwargsTypedDict
|
35
35
|
)
|
36
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
37
|
+
from snowflake.ml.model.model_signature import (
|
38
|
+
BaseFeatureSpec,
|
39
|
+
DataType,
|
40
|
+
FeatureSpec,
|
41
|
+
ModelSignature,
|
42
|
+
_infer_signature,
|
43
|
+
_rename_signature_with_snowflake_identifiers,
|
44
|
+
)
|
36
45
|
|
37
46
|
from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
|
38
47
|
|
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
|
|
43
52
|
validate_sklearn_args,
|
44
53
|
)
|
45
54
|
|
46
|
-
from snowflake.ml.model.model_signature import (
|
47
|
-
DataType,
|
48
|
-
FeatureSpec,
|
49
|
-
ModelSignature,
|
50
|
-
_infer_signature,
|
51
|
-
_rename_signature_with_snowflake_identifiers,
|
52
|
-
BaseFeatureSpec,
|
53
|
-
)
|
54
|
-
from snowflake.ml.model._signatures import utils as model_signature_utils
|
55
|
-
|
56
55
|
_PROJECT = "ModelDevelopment"
|
57
56
|
# Derive subproject from module name by removing "sklearn"
|
58
57
|
# and converting module name from underscore to CamelCase
|
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
61
60
|
|
62
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
63
62
|
|
64
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
65
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
66
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
67
|
-
return check
|
68
|
-
|
69
|
-
|
70
63
|
class LogisticRegression(BaseTransformer):
|
71
64
|
r"""Logistic Regression (aka logit, MaxEnt) classifier
|
72
65
|
For more details on this class, see [sklearn.linear_model.LogisticRegression]
|
@@ -333,12 +326,7 @@ class LogisticRegression(BaseTransformer):
|
|
333
326
|
)
|
334
327
|
return selected_cols
|
335
328
|
|
336
|
-
|
337
|
-
project=_PROJECT,
|
338
|
-
subproject=_SUBPROJECT,
|
339
|
-
custom_tags=dict([("autogen", True)]),
|
340
|
-
)
|
341
|
-
def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "LogisticRegression":
|
329
|
+
def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "LogisticRegression":
|
342
330
|
"""Fit the model according to the given training data
|
343
331
|
For more details on this function, see [sklearn.linear_model.LogisticRegression.fit]
|
344
332
|
(https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html#sklearn.linear_model.LogisticRegression.fit)
|
@@ -365,12 +353,14 @@ class LogisticRegression(BaseTransformer):
|
|
365
353
|
|
366
354
|
self._snowpark_cols = dataset.select(self.input_cols).columns
|
367
355
|
|
368
|
-
|
356
|
+
# If we are already in a stored procedure, no need to kick off another one.
|
369
357
|
if SNOWML_SPROC_ENV in os.environ:
|
370
358
|
statement_params = telemetry.get_function_usage_statement_params(
|
371
359
|
project=_PROJECT,
|
372
360
|
subproject=_SUBPROJECT,
|
373
|
-
function_name=telemetry.get_statement_params_full_func_name(
|
361
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
362
|
+
inspect.currentframe(), LogisticRegression.__class__.__name__
|
363
|
+
),
|
374
364
|
api_calls=[Session.call],
|
375
365
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
376
366
|
)
|
@@ -391,27 +381,24 @@ class LogisticRegression(BaseTransformer):
|
|
391
381
|
)
|
392
382
|
self._sklearn_object = model_trainer.train()
|
393
383
|
self._is_fitted = True
|
394
|
-
self.
|
384
|
+
self._generate_model_signatures(dataset)
|
395
385
|
return self
|
396
386
|
|
397
387
|
def _batch_inference_validate_snowpark(
|
398
388
|
self,
|
399
389
|
dataset: DataFrame,
|
400
390
|
inference_method: str,
|
401
|
-
) ->
|
402
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
403
|
-
return the available package that exists in the snowflake anaconda channel
|
391
|
+
) -> None:
|
392
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
404
393
|
|
405
394
|
Args:
|
406
395
|
dataset: snowpark dataframe
|
407
396
|
inference_method: the inference method such as predict, score...
|
408
|
-
|
397
|
+
|
409
398
|
Raises:
|
410
399
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
411
400
|
SnowflakeMLException: If the session is None, raise error
|
412
401
|
|
413
|
-
Returns:
|
414
|
-
A list of available package that exists in the snowflake anaconda channel
|
415
402
|
"""
|
416
403
|
if not self._is_fitted:
|
417
404
|
raise exceptions.SnowflakeMLException(
|
@@ -429,9 +416,7 @@ class LogisticRegression(BaseTransformer):
|
|
429
416
|
"Session must not specified for snowpark dataset."
|
430
417
|
),
|
431
418
|
)
|
432
|
-
|
433
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
434
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
419
|
+
|
435
420
|
|
436
421
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
437
422
|
@telemetry.send_api_usage_telemetry(
|
@@ -467,7 +452,9 @@ class LogisticRegression(BaseTransformer):
|
|
467
452
|
# when it is classifier, infer the datatype from label columns
|
468
453
|
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
469
454
|
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
470
|
-
label_cols_signatures = [
|
455
|
+
label_cols_signatures = [
|
456
|
+
row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
|
457
|
+
]
|
471
458
|
if len(label_cols_signatures) == 0:
|
472
459
|
error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
|
473
460
|
raise exceptions.SnowflakeMLException(
|
@@ -475,25 +462,23 @@ class LogisticRegression(BaseTransformer):
|
|
475
462
|
original_exception=ValueError(error_str),
|
476
463
|
)
|
477
464
|
|
478
|
-
expected_type_inferred = convert_sp_to_sf_type(
|
479
|
-
label_cols_signatures[0].as_snowpark_type()
|
480
|
-
)
|
465
|
+
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
481
466
|
|
482
|
-
self.
|
483
|
-
|
467
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
468
|
+
self._deps = self._get_dependencies()
|
469
|
+
assert isinstance(
|
470
|
+
dataset._session, Session
|
471
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
484
472
|
|
485
473
|
transform_kwargs = dict(
|
486
|
-
session
|
487
|
-
dependencies
|
488
|
-
drop_input_cols
|
489
|
-
expected_output_cols_type
|
474
|
+
session=dataset._session,
|
475
|
+
dependencies=self._deps,
|
476
|
+
drop_input_cols=self._drop_input_cols,
|
477
|
+
expected_output_cols_type=expected_type_inferred,
|
490
478
|
)
|
491
479
|
|
492
480
|
elif isinstance(dataset, pd.DataFrame):
|
493
|
-
transform_kwargs = dict(
|
494
|
-
snowpark_input_cols = self._snowpark_cols,
|
495
|
-
drop_input_cols = self._drop_input_cols
|
496
|
-
)
|
481
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
497
482
|
|
498
483
|
transform_handlers = ModelTransformerBuilder.build(
|
499
484
|
dataset=dataset,
|
@@ -533,7 +518,7 @@ class LogisticRegression(BaseTransformer):
|
|
533
518
|
Transformed dataset.
|
534
519
|
"""
|
535
520
|
super()._check_dataset_type(dataset)
|
536
|
-
inference_method="transform"
|
521
|
+
inference_method = "transform"
|
537
522
|
|
538
523
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
539
524
|
# are specific to the type of dataset used.
|
@@ -563,24 +548,19 @@ class LogisticRegression(BaseTransformer):
|
|
563
548
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
564
549
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
565
550
|
|
566
|
-
self.
|
567
|
-
|
568
|
-
inference_method=inference_method,
|
569
|
-
)
|
551
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
552
|
+
self._deps = self._get_dependencies()
|
570
553
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
571
554
|
|
572
555
|
transform_kwargs = dict(
|
573
|
-
session
|
574
|
-
dependencies
|
575
|
-
drop_input_cols
|
576
|
-
expected_output_cols_type
|
556
|
+
session=dataset._session,
|
557
|
+
dependencies=self._deps,
|
558
|
+
drop_input_cols=self._drop_input_cols,
|
559
|
+
expected_output_cols_type=expected_dtype,
|
577
560
|
)
|
578
561
|
|
579
562
|
elif isinstance(dataset, pd.DataFrame):
|
580
|
-
transform_kwargs = dict(
|
581
|
-
snowpark_input_cols = self._snowpark_cols,
|
582
|
-
drop_input_cols = self._drop_input_cols
|
583
|
-
)
|
563
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
584
564
|
|
585
565
|
transform_handlers = ModelTransformerBuilder.build(
|
586
566
|
dataset=dataset,
|
@@ -599,7 +579,11 @@ class LogisticRegression(BaseTransformer):
|
|
599
579
|
return output_df
|
600
580
|
|
601
581
|
@available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
|
602
|
-
def fit_predict(
|
582
|
+
def fit_predict(
|
583
|
+
self,
|
584
|
+
dataset: Union[DataFrame, pd.DataFrame],
|
585
|
+
output_cols_prefix: str = "fit_predict_",
|
586
|
+
) -> Union[DataFrame, pd.DataFrame]:
|
603
587
|
""" Method not supported for this class.
|
604
588
|
|
605
589
|
|
@@ -624,22 +608,104 @@ class LogisticRegression(BaseTransformer):
|
|
624
608
|
)
|
625
609
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
626
610
|
drop_input_cols=self._drop_input_cols,
|
627
|
-
expected_output_cols_list=
|
611
|
+
expected_output_cols_list=(
|
612
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
613
|
+
),
|
628
614
|
)
|
629
615
|
self._sklearn_object = fitted_estimator
|
630
616
|
self._is_fitted = True
|
631
617
|
return output_result
|
632
618
|
|
619
|
+
|
620
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
621
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
622
|
+
""" Method not supported for this class.
|
623
|
+
|
633
624
|
|
634
|
-
|
635
|
-
|
636
|
-
|
625
|
+
Raises:
|
626
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
627
|
+
|
628
|
+
Args:
|
629
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
630
|
+
Snowpark or Pandas DataFrame.
|
631
|
+
output_cols_prefix: Prefix for the response columns
|
637
632
|
Returns:
|
638
633
|
Transformed dataset.
|
639
634
|
"""
|
640
|
-
self.
|
641
|
-
|
642
|
-
|
635
|
+
self._infer_input_output_cols(dataset)
|
636
|
+
super()._check_dataset_type(dataset)
|
637
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
638
|
+
estimator=self._sklearn_object,
|
639
|
+
dataset=dataset,
|
640
|
+
input_cols=self.input_cols,
|
641
|
+
label_cols=self.label_cols,
|
642
|
+
sample_weight_col=self.sample_weight_col,
|
643
|
+
autogenerated=self._autogenerated,
|
644
|
+
subproject=_SUBPROJECT,
|
645
|
+
)
|
646
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
647
|
+
drop_input_cols=self._drop_input_cols,
|
648
|
+
expected_output_cols_list=self.output_cols,
|
649
|
+
)
|
650
|
+
self._sklearn_object = fitted_estimator
|
651
|
+
self._is_fitted = True
|
652
|
+
return output_result
|
653
|
+
|
654
|
+
|
655
|
+
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
656
|
+
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
657
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
658
|
+
"""
|
659
|
+
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
660
|
+
# The following condition is introduced for kneighbors methods, and not used in other methods
|
661
|
+
if output_cols:
|
662
|
+
output_cols = [
|
663
|
+
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
664
|
+
for c in output_cols
|
665
|
+
]
|
666
|
+
elif getattr(self._sklearn_object, "classes_", None) is None:
|
667
|
+
output_cols = [output_cols_prefix]
|
668
|
+
elif self._sklearn_object is not None:
|
669
|
+
classes = self._sklearn_object.classes_
|
670
|
+
if isinstance(classes, numpy.ndarray):
|
671
|
+
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
672
|
+
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
673
|
+
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
674
|
+
output_cols = []
|
675
|
+
for i, cl in enumerate(classes):
|
676
|
+
# For binary classification, there is only one output column for each class
|
677
|
+
# ndarray as the two classes are complementary.
|
678
|
+
if len(cl) == 2:
|
679
|
+
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
680
|
+
else:
|
681
|
+
output_cols.extend([
|
682
|
+
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
683
|
+
])
|
684
|
+
else:
|
685
|
+
output_cols = []
|
686
|
+
|
687
|
+
# Make sure column names are valid snowflake identifiers.
|
688
|
+
assert output_cols is not None # Make MyPy happy
|
689
|
+
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
690
|
+
|
691
|
+
return rv
|
692
|
+
|
693
|
+
def _align_expected_output_names(
|
694
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
695
|
+
) -> List[str]:
|
696
|
+
# in case the inferred output column names dimension is different
|
697
|
+
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
698
|
+
output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
|
699
|
+
output_df_columns = list(output_df_pd.columns)
|
700
|
+
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
701
|
+
if self.sample_weight_col:
|
702
|
+
output_df_columns_set -= set(self.sample_weight_col)
|
703
|
+
# if the dimension of inferred output column names is correct; use it
|
704
|
+
if len(expected_output_cols_list) == len(output_df_columns_set):
|
705
|
+
return expected_output_cols_list
|
706
|
+
# otherwise, use the sklearn estimator's output
|
707
|
+
else:
|
708
|
+
return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
643
709
|
|
644
710
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
645
711
|
@telemetry.send_api_usage_telemetry(
|
@@ -673,24 +739,26 @@ class LogisticRegression(BaseTransformer):
|
|
673
739
|
# are specific to the type of dataset used.
|
674
740
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
675
741
|
|
742
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
743
|
+
|
676
744
|
if isinstance(dataset, DataFrame):
|
677
|
-
self.
|
678
|
-
|
679
|
-
|
680
|
-
|
681
|
-
|
745
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
746
|
+
self._deps = self._get_dependencies()
|
747
|
+
assert isinstance(
|
748
|
+
dataset._session, Session
|
749
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
682
750
|
transform_kwargs = dict(
|
683
751
|
session=dataset._session,
|
684
752
|
dependencies=self._deps,
|
685
|
-
drop_input_cols
|
753
|
+
drop_input_cols=self._drop_input_cols,
|
686
754
|
expected_output_cols_type="float",
|
687
755
|
)
|
756
|
+
expected_output_cols = self._align_expected_output_names(
|
757
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
758
|
+
)
|
688
759
|
|
689
760
|
elif isinstance(dataset, pd.DataFrame):
|
690
|
-
transform_kwargs = dict(
|
691
|
-
snowpark_input_cols = self._snowpark_cols,
|
692
|
-
drop_input_cols = self._drop_input_cols
|
693
|
-
)
|
761
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
694
762
|
|
695
763
|
transform_handlers = ModelTransformerBuilder.build(
|
696
764
|
dataset=dataset,
|
@@ -702,7 +770,7 @@ class LogisticRegression(BaseTransformer):
|
|
702
770
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
703
771
|
inference_method=inference_method,
|
704
772
|
input_cols=self.input_cols,
|
705
|
-
expected_output_cols=
|
773
|
+
expected_output_cols=expected_output_cols,
|
706
774
|
**transform_kwargs
|
707
775
|
)
|
708
776
|
return output_df
|
@@ -734,29 +802,30 @@ class LogisticRegression(BaseTransformer):
|
|
734
802
|
Output dataset with log probability of the sample for each class in the model.
|
735
803
|
"""
|
736
804
|
super()._check_dataset_type(dataset)
|
737
|
-
inference_method="predict_log_proba"
|
805
|
+
inference_method = "predict_log_proba"
|
806
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
738
807
|
|
739
808
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
740
809
|
# are specific to the type of dataset used.
|
741
810
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
742
811
|
|
743
812
|
if isinstance(dataset, DataFrame):
|
744
|
-
self.
|
745
|
-
|
746
|
-
|
747
|
-
|
748
|
-
|
813
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
814
|
+
self._deps = self._get_dependencies()
|
815
|
+
assert isinstance(
|
816
|
+
dataset._session, Session
|
817
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
749
818
|
transform_kwargs = dict(
|
750
819
|
session=dataset._session,
|
751
820
|
dependencies=self._deps,
|
752
|
-
drop_input_cols
|
821
|
+
drop_input_cols=self._drop_input_cols,
|
753
822
|
expected_output_cols_type="float",
|
754
823
|
)
|
824
|
+
expected_output_cols = self._align_expected_output_names(
|
825
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
826
|
+
)
|
755
827
|
elif isinstance(dataset, pd.DataFrame):
|
756
|
-
transform_kwargs = dict(
|
757
|
-
snowpark_input_cols = self._snowpark_cols,
|
758
|
-
drop_input_cols = self._drop_input_cols
|
759
|
-
)
|
828
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
760
829
|
|
761
830
|
transform_handlers = ModelTransformerBuilder.build(
|
762
831
|
dataset=dataset,
|
@@ -769,7 +838,7 @@ class LogisticRegression(BaseTransformer):
|
|
769
838
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
770
839
|
inference_method=inference_method,
|
771
840
|
input_cols=self.input_cols,
|
772
|
-
expected_output_cols=
|
841
|
+
expected_output_cols=expected_output_cols,
|
773
842
|
**transform_kwargs
|
774
843
|
)
|
775
844
|
return output_df
|
@@ -797,30 +866,32 @@ class LogisticRegression(BaseTransformer):
|
|
797
866
|
Output dataset with results of the decision function for the samples in input dataset.
|
798
867
|
"""
|
799
868
|
super()._check_dataset_type(dataset)
|
800
|
-
inference_method="decision_function"
|
869
|
+
inference_method = "decision_function"
|
801
870
|
|
802
871
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
803
872
|
# are specific to the type of dataset used.
|
804
873
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
805
874
|
|
875
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
876
|
+
|
806
877
|
if isinstance(dataset, DataFrame):
|
807
|
-
self.
|
808
|
-
|
809
|
-
|
810
|
-
|
811
|
-
|
878
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
879
|
+
self._deps = self._get_dependencies()
|
880
|
+
assert isinstance(
|
881
|
+
dataset._session, Session
|
882
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
812
883
|
transform_kwargs = dict(
|
813
884
|
session=dataset._session,
|
814
885
|
dependencies=self._deps,
|
815
|
-
drop_input_cols
|
886
|
+
drop_input_cols=self._drop_input_cols,
|
816
887
|
expected_output_cols_type="float",
|
817
888
|
)
|
889
|
+
expected_output_cols = self._align_expected_output_names(
|
890
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
891
|
+
)
|
818
892
|
|
819
893
|
elif isinstance(dataset, pd.DataFrame):
|
820
|
-
transform_kwargs = dict(
|
821
|
-
snowpark_input_cols = self._snowpark_cols,
|
822
|
-
drop_input_cols = self._drop_input_cols
|
823
|
-
)
|
894
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
824
895
|
|
825
896
|
transform_handlers = ModelTransformerBuilder.build(
|
826
897
|
dataset=dataset,
|
@@ -833,7 +904,7 @@ class LogisticRegression(BaseTransformer):
|
|
833
904
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
834
905
|
inference_method=inference_method,
|
835
906
|
input_cols=self.input_cols,
|
836
|
-
expected_output_cols=
|
907
|
+
expected_output_cols=expected_output_cols,
|
837
908
|
**transform_kwargs
|
838
909
|
)
|
839
910
|
return output_df
|
@@ -862,17 +933,17 @@ class LogisticRegression(BaseTransformer):
|
|
862
933
|
Output dataset with probability of the sample for each class in the model.
|
863
934
|
"""
|
864
935
|
super()._check_dataset_type(dataset)
|
865
|
-
inference_method="score_samples"
|
936
|
+
inference_method = "score_samples"
|
866
937
|
|
867
938
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
868
939
|
# are specific to the type of dataset used.
|
869
940
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
870
941
|
|
942
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
943
|
+
|
871
944
|
if isinstance(dataset, DataFrame):
|
872
|
-
self.
|
873
|
-
|
874
|
-
inference_method=inference_method,
|
875
|
-
)
|
945
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
946
|
+
self._deps = self._get_dependencies()
|
876
947
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
877
948
|
transform_kwargs = dict(
|
878
949
|
session=dataset._session,
|
@@ -880,6 +951,9 @@ class LogisticRegression(BaseTransformer):
|
|
880
951
|
drop_input_cols = self._drop_input_cols,
|
881
952
|
expected_output_cols_type="float",
|
882
953
|
)
|
954
|
+
expected_output_cols = self._align_expected_output_names(
|
955
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
956
|
+
)
|
883
957
|
|
884
958
|
elif isinstance(dataset, pd.DataFrame):
|
885
959
|
transform_kwargs = dict(
|
@@ -898,7 +972,7 @@ class LogisticRegression(BaseTransformer):
|
|
898
972
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
899
973
|
inference_method=inference_method,
|
900
974
|
input_cols=self.input_cols,
|
901
|
-
expected_output_cols=
|
975
|
+
expected_output_cols=expected_output_cols,
|
902
976
|
**transform_kwargs
|
903
977
|
)
|
904
978
|
return output_df
|
@@ -933,17 +1007,15 @@ class LogisticRegression(BaseTransformer):
|
|
933
1007
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
934
1008
|
|
935
1009
|
if isinstance(dataset, DataFrame):
|
936
|
-
self.
|
937
|
-
|
938
|
-
inference_method="score",
|
939
|
-
)
|
1010
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
1011
|
+
self._deps = self._get_dependencies()
|
940
1012
|
selected_cols = self._get_active_columns()
|
941
1013
|
if len(selected_cols) > 0:
|
942
1014
|
dataset = dataset.select(selected_cols)
|
943
1015
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
944
1016
|
transform_kwargs = dict(
|
945
1017
|
session=dataset._session,
|
946
|
-
dependencies=
|
1018
|
+
dependencies=self._deps,
|
947
1019
|
score_sproc_imports=['sklearn'],
|
948
1020
|
)
|
949
1021
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1008,11 +1080,8 @@ class LogisticRegression(BaseTransformer):
|
|
1008
1080
|
|
1009
1081
|
if isinstance(dataset, DataFrame):
|
1010
1082
|
|
1011
|
-
self.
|
1012
|
-
|
1013
|
-
inference_method=inference_method,
|
1014
|
-
|
1015
|
-
)
|
1083
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1084
|
+
self._deps = self._get_dependencies()
|
1016
1085
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1017
1086
|
transform_kwargs = dict(
|
1018
1087
|
session = dataset._session,
|
@@ -1045,50 +1114,84 @@ class LogisticRegression(BaseTransformer):
|
|
1045
1114
|
)
|
1046
1115
|
return output_df
|
1047
1116
|
|
1117
|
+
|
1118
|
+
|
1119
|
+
def to_sklearn(self) -> Any:
|
1120
|
+
"""Get sklearn.linear_model.LogisticRegression object.
|
1121
|
+
"""
|
1122
|
+
if self._sklearn_object is None:
|
1123
|
+
self._sklearn_object = self._create_sklearn_object()
|
1124
|
+
return self._sklearn_object
|
1125
|
+
|
1126
|
+
def to_xgboost(self) -> Any:
|
1127
|
+
raise exceptions.SnowflakeMLException(
|
1128
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1129
|
+
original_exception=AttributeError(
|
1130
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1131
|
+
"to_xgboost()",
|
1132
|
+
"to_sklearn()"
|
1133
|
+
)
|
1134
|
+
),
|
1135
|
+
)
|
1136
|
+
|
1137
|
+
def to_lightgbm(self) -> Any:
|
1138
|
+
raise exceptions.SnowflakeMLException(
|
1139
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1140
|
+
original_exception=AttributeError(
|
1141
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1142
|
+
"to_lightgbm()",
|
1143
|
+
"to_sklearn()"
|
1144
|
+
)
|
1145
|
+
),
|
1146
|
+
)
|
1147
|
+
|
1148
|
+
def _get_dependencies(self) -> List[str]:
|
1149
|
+
return self._deps
|
1150
|
+
|
1048
1151
|
|
1049
|
-
def
|
1152
|
+
def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
1050
1153
|
self._model_signature_dict = dict()
|
1051
1154
|
|
1052
1155
|
PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
|
1053
1156
|
|
1054
|
-
inputs = list(_infer_signature(dataset[self.input_cols], "input"))
|
1157
|
+
inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
|
1055
1158
|
outputs: List[BaseFeatureSpec] = []
|
1056
1159
|
if hasattr(self, "predict"):
|
1057
1160
|
# keep mypy happy
|
1058
|
-
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1161
|
+
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1059
1162
|
# For classifier, the type of predict is the same as the type of label
|
1060
|
-
if self._sklearn_object._estimator_type ==
|
1061
|
-
|
1163
|
+
if self._sklearn_object._estimator_type == "classifier":
|
1164
|
+
# label columns is the desired type for output
|
1062
1165
|
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
1063
1166
|
# rename the output columns
|
1064
1167
|
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
1065
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1066
|
-
|
1067
|
-
|
1168
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1169
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1170
|
+
)
|
1068
1171
|
# For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
|
1069
1172
|
# For outlier models, returns -1 for outliers and 1 for inliers.
|
1070
|
-
# Clusterer returns int64 cluster labels.
|
1173
|
+
# Clusterer returns int64 cluster labels.
|
1071
1174
|
elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
|
1072
1175
|
outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
|
1073
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1074
|
-
|
1075
|
-
|
1076
|
-
|
1176
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1177
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1178
|
+
)
|
1179
|
+
|
1077
1180
|
# For regressor, the type of predict is float64
|
1078
|
-
elif self._sklearn_object._estimator_type ==
|
1181
|
+
elif self._sklearn_object._estimator_type == "regressor":
|
1079
1182
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
1080
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1081
|
-
|
1082
|
-
|
1083
|
-
|
1183
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1184
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1185
|
+
)
|
1186
|
+
|
1084
1187
|
for prob_func in PROB_FUNCTIONS:
|
1085
1188
|
if hasattr(self, prob_func):
|
1086
1189
|
output_cols_prefix: str = f"{prob_func}_"
|
1087
1190
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
1088
1191
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
1089
|
-
self._model_signature_dict[prob_func] = ModelSignature(
|
1090
|
-
|
1091
|
-
|
1192
|
+
self._model_signature_dict[prob_func] = ModelSignature(
|
1193
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1194
|
+
)
|
1092
1195
|
|
1093
1196
|
# Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
|
1094
1197
|
items = list(self._model_signature_dict.items())
|
@@ -1101,10 +1204,10 @@ class LogisticRegression(BaseTransformer):
|
|
1101
1204
|
"""Returns model signature of current class.
|
1102
1205
|
|
1103
1206
|
Raises:
|
1104
|
-
|
1207
|
+
SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
|
1105
1208
|
|
1106
1209
|
Returns:
|
1107
|
-
Dict
|
1210
|
+
Dict with each method and its input output signature
|
1108
1211
|
"""
|
1109
1212
|
if self._model_signature_dict is None:
|
1110
1213
|
raise exceptions.SnowflakeMLException(
|
@@ -1112,35 +1215,3 @@ class LogisticRegression(BaseTransformer):
|
|
1112
1215
|
original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
|
1113
1216
|
)
|
1114
1217
|
return self._model_signature_dict
|
1115
|
-
|
1116
|
-
def to_sklearn(self) -> Any:
|
1117
|
-
"""Get sklearn.linear_model.LogisticRegression object.
|
1118
|
-
"""
|
1119
|
-
if self._sklearn_object is None:
|
1120
|
-
self._sklearn_object = self._create_sklearn_object()
|
1121
|
-
return self._sklearn_object
|
1122
|
-
|
1123
|
-
def to_xgboost(self) -> Any:
|
1124
|
-
raise exceptions.SnowflakeMLException(
|
1125
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1126
|
-
original_exception=AttributeError(
|
1127
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1128
|
-
"to_xgboost()",
|
1129
|
-
"to_sklearn()"
|
1130
|
-
)
|
1131
|
-
),
|
1132
|
-
)
|
1133
|
-
|
1134
|
-
def to_lightgbm(self) -> Any:
|
1135
|
-
raise exceptions.SnowflakeMLException(
|
1136
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1137
|
-
original_exception=AttributeError(
|
1138
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1139
|
-
"to_lightgbm()",
|
1140
|
-
"to_sklearn()"
|
1141
|
-
)
|
1142
|
-
),
|
1143
|
-
)
|
1144
|
-
|
1145
|
-
def _get_dependencies(self) -> List[str]:
|
1146
|
-
return self._deps
|