snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff shows the changes between two publicly available versions of a package, each of which has been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
|
|
33
33
|
BatchInferenceKwargsTypedDict,
|
34
34
|
ScoreKwargsTypedDict
|
35
35
|
)
|
36
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
37
|
+
from snowflake.ml.model.model_signature import (
|
38
|
+
BaseFeatureSpec,
|
39
|
+
DataType,
|
40
|
+
FeatureSpec,
|
41
|
+
ModelSignature,
|
42
|
+
_infer_signature,
|
43
|
+
_rename_signature_with_snowflake_identifiers,
|
44
|
+
)
|
36
45
|
|
37
46
|
from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
|
38
47
|
|
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
|
|
43
52
|
validate_sklearn_args,
|
44
53
|
)
|
45
54
|
|
46
|
-
from snowflake.ml.model.model_signature import (
|
47
|
-
DataType,
|
48
|
-
FeatureSpec,
|
49
|
-
ModelSignature,
|
50
|
-
_infer_signature,
|
51
|
-
_rename_signature_with_snowflake_identifiers,
|
52
|
-
BaseFeatureSpec,
|
53
|
-
)
|
54
|
-
from snowflake.ml.model._signatures import utils as model_signature_utils
|
55
|
-
|
56
55
|
_PROJECT = "ModelDevelopment"
|
57
56
|
# Derive subproject from module name by removing "sklearn"
|
58
57
|
# and converting module name from underscore to CamelCase
|
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neural_network".replace(
|
|
61
60
|
|
62
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
63
62
|
|
64
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
65
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
66
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
67
|
-
return check
|
68
|
-
|
69
|
-
|
70
63
|
class MLPClassifier(BaseTransformer):
|
71
64
|
r"""Multi-layer Perceptron classifier
|
72
65
|
For more details on this class, see [sklearn.neural_network.MLPClassifier]
|
@@ -387,12 +380,7 @@ class MLPClassifier(BaseTransformer):
|
|
387
380
|
)
|
388
381
|
return selected_cols
|
389
382
|
|
390
|
-
|
391
|
-
project=_PROJECT,
|
392
|
-
subproject=_SUBPROJECT,
|
393
|
-
custom_tags=dict([("autogen", True)]),
|
394
|
-
)
|
395
|
-
def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "MLPClassifier":
|
383
|
+
def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "MLPClassifier":
|
396
384
|
"""Fit the model to data matrix X and target(s) y
|
397
385
|
For more details on this function, see [sklearn.neural_network.MLPClassifier.fit]
|
398
386
|
(https://scikit-learn.org/stable/modules/generated/sklearn.neural_network.MLPClassifier.html#sklearn.neural_network.MLPClassifier.fit)
|
@@ -419,12 +407,14 @@ class MLPClassifier(BaseTransformer):
|
|
419
407
|
|
420
408
|
self._snowpark_cols = dataset.select(self.input_cols).columns
|
421
409
|
|
422
|
-
|
410
|
+
# If we are already in a stored procedure, no need to kick off another one.
|
423
411
|
if SNOWML_SPROC_ENV in os.environ:
|
424
412
|
statement_params = telemetry.get_function_usage_statement_params(
|
425
413
|
project=_PROJECT,
|
426
414
|
subproject=_SUBPROJECT,
|
427
|
-
function_name=telemetry.get_statement_params_full_func_name(
|
415
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
416
|
+
inspect.currentframe(), MLPClassifier.__class__.__name__
|
417
|
+
),
|
428
418
|
api_calls=[Session.call],
|
429
419
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
430
420
|
)
|
@@ -445,27 +435,24 @@ class MLPClassifier(BaseTransformer):
|
|
445
435
|
)
|
446
436
|
self._sklearn_object = model_trainer.train()
|
447
437
|
self._is_fitted = True
|
448
|
-
self.
|
438
|
+
self._generate_model_signatures(dataset)
|
449
439
|
return self
|
450
440
|
|
451
441
|
def _batch_inference_validate_snowpark(
|
452
442
|
self,
|
453
443
|
dataset: DataFrame,
|
454
444
|
inference_method: str,
|
455
|
-
) ->
|
456
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
457
|
-
return the available package that exists in the snowflake anaconda channel
|
445
|
+
) -> None:
|
446
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
458
447
|
|
459
448
|
Args:
|
460
449
|
dataset: snowpark dataframe
|
461
450
|
inference_method: the inference method such as predict, score...
|
462
|
-
|
451
|
+
|
463
452
|
Raises:
|
464
453
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
465
454
|
SnowflakeMLException: If the session is None, raise error
|
466
455
|
|
467
|
-
Returns:
|
468
|
-
A list of available package that exists in the snowflake anaconda channel
|
469
456
|
"""
|
470
457
|
if not self._is_fitted:
|
471
458
|
raise exceptions.SnowflakeMLException(
|
@@ -483,9 +470,7 @@ class MLPClassifier(BaseTransformer):
|
|
483
470
|
"Session must not specified for snowpark dataset."
|
484
471
|
),
|
485
472
|
)
|
486
|
-
|
487
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
488
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
473
|
+
|
489
474
|
|
490
475
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
491
476
|
@telemetry.send_api_usage_telemetry(
|
@@ -521,7 +506,9 @@ class MLPClassifier(BaseTransformer):
|
|
521
506
|
# when it is classifier, infer the datatype from label columns
|
522
507
|
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
523
508
|
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
524
|
-
label_cols_signatures = [
|
509
|
+
label_cols_signatures = [
|
510
|
+
row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
|
511
|
+
]
|
525
512
|
if len(label_cols_signatures) == 0:
|
526
513
|
error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
|
527
514
|
raise exceptions.SnowflakeMLException(
|
@@ -529,25 +516,23 @@ class MLPClassifier(BaseTransformer):
|
|
529
516
|
original_exception=ValueError(error_str),
|
530
517
|
)
|
531
518
|
|
532
|
-
expected_type_inferred = convert_sp_to_sf_type(
|
533
|
-
label_cols_signatures[0].as_snowpark_type()
|
534
|
-
)
|
519
|
+
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
535
520
|
|
536
|
-
self.
|
537
|
-
|
521
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
522
|
+
self._deps = self._get_dependencies()
|
523
|
+
assert isinstance(
|
524
|
+
dataset._session, Session
|
525
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
538
526
|
|
539
527
|
transform_kwargs = dict(
|
540
|
-
session
|
541
|
-
dependencies
|
542
|
-
drop_input_cols
|
543
|
-
expected_output_cols_type
|
528
|
+
session=dataset._session,
|
529
|
+
dependencies=self._deps,
|
530
|
+
drop_input_cols=self._drop_input_cols,
|
531
|
+
expected_output_cols_type=expected_type_inferred,
|
544
532
|
)
|
545
533
|
|
546
534
|
elif isinstance(dataset, pd.DataFrame):
|
547
|
-
transform_kwargs = dict(
|
548
|
-
snowpark_input_cols = self._snowpark_cols,
|
549
|
-
drop_input_cols = self._drop_input_cols
|
550
|
-
)
|
535
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
551
536
|
|
552
537
|
transform_handlers = ModelTransformerBuilder.build(
|
553
538
|
dataset=dataset,
|
@@ -587,7 +572,7 @@ class MLPClassifier(BaseTransformer):
|
|
587
572
|
Transformed dataset.
|
588
573
|
"""
|
589
574
|
super()._check_dataset_type(dataset)
|
590
|
-
inference_method="transform"
|
575
|
+
inference_method = "transform"
|
591
576
|
|
592
577
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
593
578
|
# are specific to the type of dataset used.
|
@@ -617,24 +602,19 @@ class MLPClassifier(BaseTransformer):
|
|
617
602
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
618
603
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
619
604
|
|
620
|
-
self.
|
621
|
-
|
622
|
-
inference_method=inference_method,
|
623
|
-
)
|
605
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
606
|
+
self._deps = self._get_dependencies()
|
624
607
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
625
608
|
|
626
609
|
transform_kwargs = dict(
|
627
|
-
session
|
628
|
-
dependencies
|
629
|
-
drop_input_cols
|
630
|
-
expected_output_cols_type
|
610
|
+
session=dataset._session,
|
611
|
+
dependencies=self._deps,
|
612
|
+
drop_input_cols=self._drop_input_cols,
|
613
|
+
expected_output_cols_type=expected_dtype,
|
631
614
|
)
|
632
615
|
|
633
616
|
elif isinstance(dataset, pd.DataFrame):
|
634
|
-
transform_kwargs = dict(
|
635
|
-
snowpark_input_cols = self._snowpark_cols,
|
636
|
-
drop_input_cols = self._drop_input_cols
|
637
|
-
)
|
617
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
638
618
|
|
639
619
|
transform_handlers = ModelTransformerBuilder.build(
|
640
620
|
dataset=dataset,
|
@@ -653,7 +633,11 @@ class MLPClassifier(BaseTransformer):
|
|
653
633
|
return output_df
|
654
634
|
|
655
635
|
@available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
|
656
|
-
def fit_predict(
|
636
|
+
def fit_predict(
|
637
|
+
self,
|
638
|
+
dataset: Union[DataFrame, pd.DataFrame],
|
639
|
+
output_cols_prefix: str = "fit_predict_",
|
640
|
+
) -> Union[DataFrame, pd.DataFrame]:
|
657
641
|
""" Method not supported for this class.
|
658
642
|
|
659
643
|
|
@@ -678,22 +662,104 @@ class MLPClassifier(BaseTransformer):
|
|
678
662
|
)
|
679
663
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
680
664
|
drop_input_cols=self._drop_input_cols,
|
681
|
-
expected_output_cols_list=
|
665
|
+
expected_output_cols_list=(
|
666
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
667
|
+
),
|
682
668
|
)
|
683
669
|
self._sklearn_object = fitted_estimator
|
684
670
|
self._is_fitted = True
|
685
671
|
return output_result
|
686
672
|
|
673
|
+
|
674
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
675
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
676
|
+
""" Method not supported for this class.
|
677
|
+
|
687
678
|
|
688
|
-
|
689
|
-
|
690
|
-
|
679
|
+
Raises:
|
680
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
681
|
+
|
682
|
+
Args:
|
683
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
684
|
+
Snowpark or Pandas DataFrame.
|
685
|
+
output_cols_prefix: Prefix for the response columns
|
691
686
|
Returns:
|
692
687
|
Transformed dataset.
|
693
688
|
"""
|
694
|
-
self.
|
695
|
-
|
696
|
-
|
689
|
+
self._infer_input_output_cols(dataset)
|
690
|
+
super()._check_dataset_type(dataset)
|
691
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
692
|
+
estimator=self._sklearn_object,
|
693
|
+
dataset=dataset,
|
694
|
+
input_cols=self.input_cols,
|
695
|
+
label_cols=self.label_cols,
|
696
|
+
sample_weight_col=self.sample_weight_col,
|
697
|
+
autogenerated=self._autogenerated,
|
698
|
+
subproject=_SUBPROJECT,
|
699
|
+
)
|
700
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
701
|
+
drop_input_cols=self._drop_input_cols,
|
702
|
+
expected_output_cols_list=self.output_cols,
|
703
|
+
)
|
704
|
+
self._sklearn_object = fitted_estimator
|
705
|
+
self._is_fitted = True
|
706
|
+
return output_result
|
707
|
+
|
708
|
+
|
709
|
+
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
710
|
+
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
711
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
712
|
+
"""
|
713
|
+
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
714
|
+
# The following condition is introduced for kneighbors methods, and not used in other methods
|
715
|
+
if output_cols:
|
716
|
+
output_cols = [
|
717
|
+
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
718
|
+
for c in output_cols
|
719
|
+
]
|
720
|
+
elif getattr(self._sklearn_object, "classes_", None) is None:
|
721
|
+
output_cols = [output_cols_prefix]
|
722
|
+
elif self._sklearn_object is not None:
|
723
|
+
classes = self._sklearn_object.classes_
|
724
|
+
if isinstance(classes, numpy.ndarray):
|
725
|
+
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
726
|
+
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
727
|
+
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
728
|
+
output_cols = []
|
729
|
+
for i, cl in enumerate(classes):
|
730
|
+
# For binary classification, there is only one output column for each class
|
731
|
+
# ndarray as the two classes are complementary.
|
732
|
+
if len(cl) == 2:
|
733
|
+
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
734
|
+
else:
|
735
|
+
output_cols.extend([
|
736
|
+
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
737
|
+
])
|
738
|
+
else:
|
739
|
+
output_cols = []
|
740
|
+
|
741
|
+
# Make sure column names are valid snowflake identifiers.
|
742
|
+
assert output_cols is not None # Make MyPy happy
|
743
|
+
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
744
|
+
|
745
|
+
return rv
|
746
|
+
|
747
|
+
def _align_expected_output_names(
|
748
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
749
|
+
) -> List[str]:
|
750
|
+
# in case the inferred output column names dimension is different
|
751
|
+
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
752
|
+
output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
|
753
|
+
output_df_columns = list(output_df_pd.columns)
|
754
|
+
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
755
|
+
if self.sample_weight_col:
|
756
|
+
output_df_columns_set -= set(self.sample_weight_col)
|
757
|
+
# if the dimension of inferred output column names is correct; use it
|
758
|
+
if len(expected_output_cols_list) == len(output_df_columns_set):
|
759
|
+
return expected_output_cols_list
|
760
|
+
# otherwise, use the sklearn estimator's output
|
761
|
+
else:
|
762
|
+
return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
697
763
|
|
698
764
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
699
765
|
@telemetry.send_api_usage_telemetry(
|
@@ -727,24 +793,26 @@ class MLPClassifier(BaseTransformer):
|
|
727
793
|
# are specific to the type of dataset used.
|
728
794
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
729
795
|
|
796
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
797
|
+
|
730
798
|
if isinstance(dataset, DataFrame):
|
731
|
-
self.
|
732
|
-
|
733
|
-
|
734
|
-
|
735
|
-
|
799
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
800
|
+
self._deps = self._get_dependencies()
|
801
|
+
assert isinstance(
|
802
|
+
dataset._session, Session
|
803
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
736
804
|
transform_kwargs = dict(
|
737
805
|
session=dataset._session,
|
738
806
|
dependencies=self._deps,
|
739
|
-
drop_input_cols
|
807
|
+
drop_input_cols=self._drop_input_cols,
|
740
808
|
expected_output_cols_type="float",
|
741
809
|
)
|
810
|
+
expected_output_cols = self._align_expected_output_names(
|
811
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
812
|
+
)
|
742
813
|
|
743
814
|
elif isinstance(dataset, pd.DataFrame):
|
744
|
-
transform_kwargs = dict(
|
745
|
-
snowpark_input_cols = self._snowpark_cols,
|
746
|
-
drop_input_cols = self._drop_input_cols
|
747
|
-
)
|
815
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
748
816
|
|
749
817
|
transform_handlers = ModelTransformerBuilder.build(
|
750
818
|
dataset=dataset,
|
@@ -756,7 +824,7 @@ class MLPClassifier(BaseTransformer):
|
|
756
824
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
757
825
|
inference_method=inference_method,
|
758
826
|
input_cols=self.input_cols,
|
759
|
-
expected_output_cols=
|
827
|
+
expected_output_cols=expected_output_cols,
|
760
828
|
**transform_kwargs
|
761
829
|
)
|
762
830
|
return output_df
|
@@ -788,29 +856,30 @@ class MLPClassifier(BaseTransformer):
|
|
788
856
|
Output dataset with log probability of the sample for each class in the model.
|
789
857
|
"""
|
790
858
|
super()._check_dataset_type(dataset)
|
791
|
-
inference_method="predict_log_proba"
|
859
|
+
inference_method = "predict_log_proba"
|
860
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
792
861
|
|
793
862
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
794
863
|
# are specific to the type of dataset used.
|
795
864
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
796
865
|
|
797
866
|
if isinstance(dataset, DataFrame):
|
798
|
-
self.
|
799
|
-
|
800
|
-
|
801
|
-
|
802
|
-
|
867
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
868
|
+
self._deps = self._get_dependencies()
|
869
|
+
assert isinstance(
|
870
|
+
dataset._session, Session
|
871
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
803
872
|
transform_kwargs = dict(
|
804
873
|
session=dataset._session,
|
805
874
|
dependencies=self._deps,
|
806
|
-
drop_input_cols
|
875
|
+
drop_input_cols=self._drop_input_cols,
|
807
876
|
expected_output_cols_type="float",
|
808
877
|
)
|
878
|
+
expected_output_cols = self._align_expected_output_names(
|
879
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
880
|
+
)
|
809
881
|
elif isinstance(dataset, pd.DataFrame):
|
810
|
-
transform_kwargs = dict(
|
811
|
-
snowpark_input_cols = self._snowpark_cols,
|
812
|
-
drop_input_cols = self._drop_input_cols
|
813
|
-
)
|
882
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
814
883
|
|
815
884
|
transform_handlers = ModelTransformerBuilder.build(
|
816
885
|
dataset=dataset,
|
@@ -823,7 +892,7 @@ class MLPClassifier(BaseTransformer):
|
|
823
892
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
824
893
|
inference_method=inference_method,
|
825
894
|
input_cols=self.input_cols,
|
826
|
-
expected_output_cols=
|
895
|
+
expected_output_cols=expected_output_cols,
|
827
896
|
**transform_kwargs
|
828
897
|
)
|
829
898
|
return output_df
|
@@ -849,30 +918,32 @@ class MLPClassifier(BaseTransformer):
|
|
849
918
|
Output dataset with results of the decision function for the samples in input dataset.
|
850
919
|
"""
|
851
920
|
super()._check_dataset_type(dataset)
|
852
|
-
inference_method="decision_function"
|
921
|
+
inference_method = "decision_function"
|
853
922
|
|
854
923
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
855
924
|
# are specific to the type of dataset used.
|
856
925
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
857
926
|
|
927
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
928
|
+
|
858
929
|
if isinstance(dataset, DataFrame):
|
859
|
-
self.
|
860
|
-
|
861
|
-
|
862
|
-
|
863
|
-
|
930
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
931
|
+
self._deps = self._get_dependencies()
|
932
|
+
assert isinstance(
|
933
|
+
dataset._session, Session
|
934
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
864
935
|
transform_kwargs = dict(
|
865
936
|
session=dataset._session,
|
866
937
|
dependencies=self._deps,
|
867
|
-
drop_input_cols
|
938
|
+
drop_input_cols=self._drop_input_cols,
|
868
939
|
expected_output_cols_type="float",
|
869
940
|
)
|
941
|
+
expected_output_cols = self._align_expected_output_names(
|
942
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
943
|
+
)
|
870
944
|
|
871
945
|
elif isinstance(dataset, pd.DataFrame):
|
872
|
-
transform_kwargs = dict(
|
873
|
-
snowpark_input_cols = self._snowpark_cols,
|
874
|
-
drop_input_cols = self._drop_input_cols
|
875
|
-
)
|
946
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
876
947
|
|
877
948
|
transform_handlers = ModelTransformerBuilder.build(
|
878
949
|
dataset=dataset,
|
@@ -885,7 +956,7 @@ class MLPClassifier(BaseTransformer):
|
|
885
956
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
886
957
|
inference_method=inference_method,
|
887
958
|
input_cols=self.input_cols,
|
888
|
-
expected_output_cols=
|
959
|
+
expected_output_cols=expected_output_cols,
|
889
960
|
**transform_kwargs
|
890
961
|
)
|
891
962
|
return output_df
|
@@ -914,17 +985,17 @@ class MLPClassifier(BaseTransformer):
|
|
914
985
|
Output dataset with probability of the sample for each class in the model.
|
915
986
|
"""
|
916
987
|
super()._check_dataset_type(dataset)
|
917
|
-
inference_method="score_samples"
|
988
|
+
inference_method = "score_samples"
|
918
989
|
|
919
990
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
920
991
|
# are specific to the type of dataset used.
|
921
992
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
922
993
|
|
994
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
995
|
+
|
923
996
|
if isinstance(dataset, DataFrame):
|
924
|
-
self.
|
925
|
-
|
926
|
-
inference_method=inference_method,
|
927
|
-
)
|
997
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
998
|
+
self._deps = self._get_dependencies()
|
928
999
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
929
1000
|
transform_kwargs = dict(
|
930
1001
|
session=dataset._session,
|
@@ -932,6 +1003,9 @@ class MLPClassifier(BaseTransformer):
|
|
932
1003
|
drop_input_cols = self._drop_input_cols,
|
933
1004
|
expected_output_cols_type="float",
|
934
1005
|
)
|
1006
|
+
expected_output_cols = self._align_expected_output_names(
|
1007
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
1008
|
+
)
|
935
1009
|
|
936
1010
|
elif isinstance(dataset, pd.DataFrame):
|
937
1011
|
transform_kwargs = dict(
|
@@ -950,7 +1024,7 @@ class MLPClassifier(BaseTransformer):
|
|
950
1024
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
951
1025
|
inference_method=inference_method,
|
952
1026
|
input_cols=self.input_cols,
|
953
|
-
expected_output_cols=
|
1027
|
+
expected_output_cols=expected_output_cols,
|
954
1028
|
**transform_kwargs
|
955
1029
|
)
|
956
1030
|
return output_df
|
@@ -985,17 +1059,15 @@ class MLPClassifier(BaseTransformer):
|
|
985
1059
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
986
1060
|
|
987
1061
|
if isinstance(dataset, DataFrame):
|
988
|
-
self.
|
989
|
-
|
990
|
-
inference_method="score",
|
991
|
-
)
|
1062
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
1063
|
+
self._deps = self._get_dependencies()
|
992
1064
|
selected_cols = self._get_active_columns()
|
993
1065
|
if len(selected_cols) > 0:
|
994
1066
|
dataset = dataset.select(selected_cols)
|
995
1067
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
996
1068
|
transform_kwargs = dict(
|
997
1069
|
session=dataset._session,
|
998
|
-
dependencies=
|
1070
|
+
dependencies=self._deps,
|
999
1071
|
score_sproc_imports=['sklearn'],
|
1000
1072
|
)
|
1001
1073
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1060,11 +1132,8 @@ class MLPClassifier(BaseTransformer):
|
|
1060
1132
|
|
1061
1133
|
if isinstance(dataset, DataFrame):
|
1062
1134
|
|
1063
|
-
self.
|
1064
|
-
|
1065
|
-
inference_method=inference_method,
|
1066
|
-
|
1067
|
-
)
|
1135
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1136
|
+
self._deps = self._get_dependencies()
|
1068
1137
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1069
1138
|
transform_kwargs = dict(
|
1070
1139
|
session = dataset._session,
|
@@ -1097,50 +1166,84 @@ class MLPClassifier(BaseTransformer):
|
|
1097
1166
|
)
|
1098
1167
|
return output_df
|
1099
1168
|
|
1169
|
+
|
1170
|
+
|
1171
|
+
def to_sklearn(self) -> Any:
|
1172
|
+
"""Get sklearn.neural_network.MLPClassifier object.
|
1173
|
+
"""
|
1174
|
+
if self._sklearn_object is None:
|
1175
|
+
self._sklearn_object = self._create_sklearn_object()
|
1176
|
+
return self._sklearn_object
|
1177
|
+
|
1178
|
+
def to_xgboost(self) -> Any:
|
1179
|
+
raise exceptions.SnowflakeMLException(
|
1180
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1181
|
+
original_exception=AttributeError(
|
1182
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1183
|
+
"to_xgboost()",
|
1184
|
+
"to_sklearn()"
|
1185
|
+
)
|
1186
|
+
),
|
1187
|
+
)
|
1188
|
+
|
1189
|
+
def to_lightgbm(self) -> Any:
|
1190
|
+
raise exceptions.SnowflakeMLException(
|
1191
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1192
|
+
original_exception=AttributeError(
|
1193
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1194
|
+
"to_lightgbm()",
|
1195
|
+
"to_sklearn()"
|
1196
|
+
)
|
1197
|
+
),
|
1198
|
+
)
|
1199
|
+
|
1200
|
+
def _get_dependencies(self) -> List[str]:
|
1201
|
+
return self._deps
|
1202
|
+
|
1100
1203
|
|
1101
|
-
def
|
1204
|
+
def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
1102
1205
|
self._model_signature_dict = dict()
|
1103
1206
|
|
1104
1207
|
PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
|
1105
1208
|
|
1106
|
-
inputs = list(_infer_signature(dataset[self.input_cols], "input"))
|
1209
|
+
inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
|
1107
1210
|
outputs: List[BaseFeatureSpec] = []
|
1108
1211
|
if hasattr(self, "predict"):
|
1109
1212
|
# keep mypy happy
|
1110
|
-
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1213
|
+
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1111
1214
|
# For classifier, the type of predict is the same as the type of label
|
1112
|
-
if self._sklearn_object._estimator_type ==
|
1113
|
-
|
1215
|
+
if self._sklearn_object._estimator_type == "classifier":
|
1216
|
+
# label columns is the desired type for output
|
1114
1217
|
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
1115
1218
|
# rename the output columns
|
1116
1219
|
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
1117
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1118
|
-
|
1119
|
-
|
1220
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1221
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1222
|
+
)
|
1120
1223
|
# For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
|
1121
1224
|
# For outlier models, returns -1 for outliers and 1 for inliers.
|
1122
|
-
# Clusterer returns int64 cluster labels.
|
1225
|
+
# Clusterer returns int64 cluster labels.
|
1123
1226
|
elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
|
1124
1227
|
outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
|
1125
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1126
|
-
|
1127
|
-
|
1128
|
-
|
1228
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1229
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1230
|
+
)
|
1231
|
+
|
1129
1232
|
# For regressor, the type of predict is float64
|
1130
|
-
elif self._sklearn_object._estimator_type ==
|
1233
|
+
elif self._sklearn_object._estimator_type == "regressor":
|
1131
1234
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
1132
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1133
|
-
|
1134
|
-
|
1135
|
-
|
1235
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1236
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1237
|
+
)
|
1238
|
+
|
1136
1239
|
for prob_func in PROB_FUNCTIONS:
|
1137
1240
|
if hasattr(self, prob_func):
|
1138
1241
|
output_cols_prefix: str = f"{prob_func}_"
|
1139
1242
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
1140
1243
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
1141
|
-
self._model_signature_dict[prob_func] = ModelSignature(
|
1142
|
-
|
1143
|
-
|
1244
|
+
self._model_signature_dict[prob_func] = ModelSignature(
|
1245
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1246
|
+
)
|
1144
1247
|
|
1145
1248
|
# Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
|
1146
1249
|
items = list(self._model_signature_dict.items())
|
@@ -1153,10 +1256,10 @@ class MLPClassifier(BaseTransformer):
|
|
1153
1256
|
"""Returns model signature of current class.
|
1154
1257
|
|
1155
1258
|
Raises:
|
1156
|
-
|
1259
|
+
SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
|
1157
1260
|
|
1158
1261
|
Returns:
|
1159
|
-
Dict
|
1262
|
+
Dict with each method and its input output signature
|
1160
1263
|
"""
|
1161
1264
|
if self._model_signature_dict is None:
|
1162
1265
|
raise exceptions.SnowflakeMLException(
|
@@ -1164,35 +1267,3 @@ class MLPClassifier(BaseTransformer):
|
|
1164
1267
|
original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
|
1165
1268
|
)
|
1166
1269
|
return self._model_signature_dict
|
1167
|
-
|
1168
|
-
def to_sklearn(self) -> Any:
|
1169
|
-
"""Get sklearn.neural_network.MLPClassifier object.
|
1170
|
-
"""
|
1171
|
-
if self._sklearn_object is None:
|
1172
|
-
self._sklearn_object = self._create_sklearn_object()
|
1173
|
-
return self._sklearn_object
|
1174
|
-
|
1175
|
-
def to_xgboost(self) -> Any:
|
1176
|
-
raise exceptions.SnowflakeMLException(
|
1177
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1178
|
-
original_exception=AttributeError(
|
1179
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1180
|
-
"to_xgboost()",
|
1181
|
-
"to_sklearn()"
|
1182
|
-
)
|
1183
|
-
),
|
1184
|
-
)
|
1185
|
-
|
1186
|
-
def to_lightgbm(self) -> Any:
|
1187
|
-
raise exceptions.SnowflakeMLException(
|
1188
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1189
|
-
original_exception=AttributeError(
|
1190
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1191
|
-
"to_lightgbm()",
|
1192
|
-
"to_sklearn()"
|
1193
|
-
)
|
1194
|
-
),
|
1195
|
-
)
|
1196
|
-
|
1197
|
-
def _get_dependencies(self) -> List[str]:
|
1198
|
-
return self._deps
|