snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
|
|
33
33
|
BatchInferenceKwargsTypedDict,
|
34
34
|
ScoreKwargsTypedDict
|
35
35
|
)
|
36
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
37
|
+
from snowflake.ml.model.model_signature import (
|
38
|
+
BaseFeatureSpec,
|
39
|
+
DataType,
|
40
|
+
FeatureSpec,
|
41
|
+
ModelSignature,
|
42
|
+
_infer_signature,
|
43
|
+
_rename_signature_with_snowflake_identifiers,
|
44
|
+
)
|
36
45
|
|
37
46
|
from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
|
38
47
|
|
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
|
|
43
52
|
validate_sklearn_args,
|
44
53
|
)
|
45
54
|
|
46
|
-
from snowflake.ml.model.model_signature import (
|
47
|
-
DataType,
|
48
|
-
FeatureSpec,
|
49
|
-
ModelSignature,
|
50
|
-
_infer_signature,
|
51
|
-
_rename_signature_with_snowflake_identifiers,
|
52
|
-
BaseFeatureSpec,
|
53
|
-
)
|
54
|
-
from snowflake.ml.model._signatures import utils as model_signature_utils
|
55
|
-
|
56
55
|
_PROJECT = "ModelDevelopment"
|
57
56
|
# Derive subproject from module name by removing "sklearn"
|
58
57
|
# and converting module name from underscore to CamelCase
|
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
61
60
|
|
62
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
63
62
|
|
64
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
65
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
66
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
67
|
-
return check
|
68
|
-
|
69
|
-
|
70
63
|
class Perceptron(BaseTransformer):
|
71
64
|
r"""Linear perceptron classifier
|
72
65
|
For more details on this class, see [sklearn.linear_model.Perceptron]
|
@@ -300,12 +293,7 @@ class Perceptron(BaseTransformer):
|
|
300
293
|
)
|
301
294
|
return selected_cols
|
302
295
|
|
303
|
-
|
304
|
-
project=_PROJECT,
|
305
|
-
subproject=_SUBPROJECT,
|
306
|
-
custom_tags=dict([("autogen", True)]),
|
307
|
-
)
|
308
|
-
def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "Perceptron":
|
296
|
+
def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "Perceptron":
|
309
297
|
"""Fit linear model with Stochastic Gradient Descent
|
310
298
|
For more details on this function, see [sklearn.linear_model.Perceptron.fit]
|
311
299
|
(https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Perceptron.html#sklearn.linear_model.Perceptron.fit)
|
@@ -332,12 +320,14 @@ class Perceptron(BaseTransformer):
|
|
332
320
|
|
333
321
|
self._snowpark_cols = dataset.select(self.input_cols).columns
|
334
322
|
|
335
|
-
|
323
|
+
# If we are already in a stored procedure, no need to kick off another one.
|
336
324
|
if SNOWML_SPROC_ENV in os.environ:
|
337
325
|
statement_params = telemetry.get_function_usage_statement_params(
|
338
326
|
project=_PROJECT,
|
339
327
|
subproject=_SUBPROJECT,
|
340
|
-
function_name=telemetry.get_statement_params_full_func_name(
|
328
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
329
|
+
inspect.currentframe(), Perceptron.__class__.__name__
|
330
|
+
),
|
341
331
|
api_calls=[Session.call],
|
342
332
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
343
333
|
)
|
@@ -358,27 +348,24 @@ class Perceptron(BaseTransformer):
|
|
358
348
|
)
|
359
349
|
self._sklearn_object = model_trainer.train()
|
360
350
|
self._is_fitted = True
|
361
|
-
self.
|
351
|
+
self._generate_model_signatures(dataset)
|
362
352
|
return self
|
363
353
|
|
364
354
|
def _batch_inference_validate_snowpark(
|
365
355
|
self,
|
366
356
|
dataset: DataFrame,
|
367
357
|
inference_method: str,
|
368
|
-
) ->
|
369
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
370
|
-
return the available package that exists in the snowflake anaconda channel
|
358
|
+
) -> None:
|
359
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
371
360
|
|
372
361
|
Args:
|
373
362
|
dataset: snowpark dataframe
|
374
363
|
inference_method: the inference method such as predict, score...
|
375
|
-
|
364
|
+
|
376
365
|
Raises:
|
377
366
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
378
367
|
SnowflakeMLException: If the session is None, raise error
|
379
368
|
|
380
|
-
Returns:
|
381
|
-
A list of available package that exists in the snowflake anaconda channel
|
382
369
|
"""
|
383
370
|
if not self._is_fitted:
|
384
371
|
raise exceptions.SnowflakeMLException(
|
@@ -396,9 +383,7 @@ class Perceptron(BaseTransformer):
|
|
396
383
|
"Session must not specified for snowpark dataset."
|
397
384
|
),
|
398
385
|
)
|
399
|
-
|
400
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
401
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
386
|
+
|
402
387
|
|
403
388
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
404
389
|
@telemetry.send_api_usage_telemetry(
|
@@ -434,7 +419,9 @@ class Perceptron(BaseTransformer):
|
|
434
419
|
# when it is classifier, infer the datatype from label columns
|
435
420
|
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
436
421
|
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
437
|
-
label_cols_signatures = [
|
422
|
+
label_cols_signatures = [
|
423
|
+
row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
|
424
|
+
]
|
438
425
|
if len(label_cols_signatures) == 0:
|
439
426
|
error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
|
440
427
|
raise exceptions.SnowflakeMLException(
|
@@ -442,25 +429,23 @@ class Perceptron(BaseTransformer):
|
|
442
429
|
original_exception=ValueError(error_str),
|
443
430
|
)
|
444
431
|
|
445
|
-
expected_type_inferred = convert_sp_to_sf_type(
|
446
|
-
label_cols_signatures[0].as_snowpark_type()
|
447
|
-
)
|
432
|
+
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
448
433
|
|
449
|
-
self.
|
450
|
-
|
434
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
435
|
+
self._deps = self._get_dependencies()
|
436
|
+
assert isinstance(
|
437
|
+
dataset._session, Session
|
438
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
451
439
|
|
452
440
|
transform_kwargs = dict(
|
453
|
-
session
|
454
|
-
dependencies
|
455
|
-
drop_input_cols
|
456
|
-
expected_output_cols_type
|
441
|
+
session=dataset._session,
|
442
|
+
dependencies=self._deps,
|
443
|
+
drop_input_cols=self._drop_input_cols,
|
444
|
+
expected_output_cols_type=expected_type_inferred,
|
457
445
|
)
|
458
446
|
|
459
447
|
elif isinstance(dataset, pd.DataFrame):
|
460
|
-
transform_kwargs = dict(
|
461
|
-
snowpark_input_cols = self._snowpark_cols,
|
462
|
-
drop_input_cols = self._drop_input_cols
|
463
|
-
)
|
448
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
464
449
|
|
465
450
|
transform_handlers = ModelTransformerBuilder.build(
|
466
451
|
dataset=dataset,
|
@@ -500,7 +485,7 @@ class Perceptron(BaseTransformer):
|
|
500
485
|
Transformed dataset.
|
501
486
|
"""
|
502
487
|
super()._check_dataset_type(dataset)
|
503
|
-
inference_method="transform"
|
488
|
+
inference_method = "transform"
|
504
489
|
|
505
490
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
506
491
|
# are specific to the type of dataset used.
|
@@ -530,24 +515,19 @@ class Perceptron(BaseTransformer):
|
|
530
515
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
531
516
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
532
517
|
|
533
|
-
self.
|
534
|
-
|
535
|
-
inference_method=inference_method,
|
536
|
-
)
|
518
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
519
|
+
self._deps = self._get_dependencies()
|
537
520
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
538
521
|
|
539
522
|
transform_kwargs = dict(
|
540
|
-
session
|
541
|
-
dependencies
|
542
|
-
drop_input_cols
|
543
|
-
expected_output_cols_type
|
523
|
+
session=dataset._session,
|
524
|
+
dependencies=self._deps,
|
525
|
+
drop_input_cols=self._drop_input_cols,
|
526
|
+
expected_output_cols_type=expected_dtype,
|
544
527
|
)
|
545
528
|
|
546
529
|
elif isinstance(dataset, pd.DataFrame):
|
547
|
-
transform_kwargs = dict(
|
548
|
-
snowpark_input_cols = self._snowpark_cols,
|
549
|
-
drop_input_cols = self._drop_input_cols
|
550
|
-
)
|
530
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
551
531
|
|
552
532
|
transform_handlers = ModelTransformerBuilder.build(
|
553
533
|
dataset=dataset,
|
@@ -566,7 +546,11 @@ class Perceptron(BaseTransformer):
|
|
566
546
|
return output_df
|
567
547
|
|
568
548
|
@available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
|
569
|
-
def fit_predict(
|
549
|
+
def fit_predict(
|
550
|
+
self,
|
551
|
+
dataset: Union[DataFrame, pd.DataFrame],
|
552
|
+
output_cols_prefix: str = "fit_predict_",
|
553
|
+
) -> Union[DataFrame, pd.DataFrame]:
|
570
554
|
""" Method not supported for this class.
|
571
555
|
|
572
556
|
|
@@ -591,22 +575,104 @@ class Perceptron(BaseTransformer):
|
|
591
575
|
)
|
592
576
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
593
577
|
drop_input_cols=self._drop_input_cols,
|
594
|
-
expected_output_cols_list=
|
578
|
+
expected_output_cols_list=(
|
579
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
580
|
+
),
|
595
581
|
)
|
596
582
|
self._sklearn_object = fitted_estimator
|
597
583
|
self._is_fitted = True
|
598
584
|
return output_result
|
599
585
|
|
586
|
+
|
587
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
588
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
589
|
+
""" Method not supported for this class.
|
590
|
+
|
600
591
|
|
601
|
-
|
602
|
-
|
603
|
-
|
592
|
+
Raises:
|
593
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
594
|
+
|
595
|
+
Args:
|
596
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
597
|
+
Snowpark or Pandas DataFrame.
|
598
|
+
output_cols_prefix: Prefix for the response columns
|
604
599
|
Returns:
|
605
600
|
Transformed dataset.
|
606
601
|
"""
|
607
|
-
self.
|
608
|
-
|
609
|
-
|
602
|
+
self._infer_input_output_cols(dataset)
|
603
|
+
super()._check_dataset_type(dataset)
|
604
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
605
|
+
estimator=self._sklearn_object,
|
606
|
+
dataset=dataset,
|
607
|
+
input_cols=self.input_cols,
|
608
|
+
label_cols=self.label_cols,
|
609
|
+
sample_weight_col=self.sample_weight_col,
|
610
|
+
autogenerated=self._autogenerated,
|
611
|
+
subproject=_SUBPROJECT,
|
612
|
+
)
|
613
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
614
|
+
drop_input_cols=self._drop_input_cols,
|
615
|
+
expected_output_cols_list=self.output_cols,
|
616
|
+
)
|
617
|
+
self._sklearn_object = fitted_estimator
|
618
|
+
self._is_fitted = True
|
619
|
+
return output_result
|
620
|
+
|
621
|
+
|
622
|
+
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
623
|
+
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
624
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
625
|
+
"""
|
626
|
+
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
627
|
+
# The following condition is introduced for kneighbors methods, and not used in other methods
|
628
|
+
if output_cols:
|
629
|
+
output_cols = [
|
630
|
+
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
631
|
+
for c in output_cols
|
632
|
+
]
|
633
|
+
elif getattr(self._sklearn_object, "classes_", None) is None:
|
634
|
+
output_cols = [output_cols_prefix]
|
635
|
+
elif self._sklearn_object is not None:
|
636
|
+
classes = self._sklearn_object.classes_
|
637
|
+
if isinstance(classes, numpy.ndarray):
|
638
|
+
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
639
|
+
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
640
|
+
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
641
|
+
output_cols = []
|
642
|
+
for i, cl in enumerate(classes):
|
643
|
+
# For binary classification, there is only one output column for each class
|
644
|
+
# ndarray as the two classes are complementary.
|
645
|
+
if len(cl) == 2:
|
646
|
+
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
647
|
+
else:
|
648
|
+
output_cols.extend([
|
649
|
+
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
650
|
+
])
|
651
|
+
else:
|
652
|
+
output_cols = []
|
653
|
+
|
654
|
+
# Make sure column names are valid snowflake identifiers.
|
655
|
+
assert output_cols is not None # Make MyPy happy
|
656
|
+
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
657
|
+
|
658
|
+
return rv
|
659
|
+
|
660
|
+
def _align_expected_output_names(
|
661
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
662
|
+
) -> List[str]:
|
663
|
+
# in case the inferred output column names dimension is different
|
664
|
+
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
665
|
+
output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
|
666
|
+
output_df_columns = list(output_df_pd.columns)
|
667
|
+
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
668
|
+
if self.sample_weight_col:
|
669
|
+
output_df_columns_set -= set(self.sample_weight_col)
|
670
|
+
# if the dimension of inferred output column names is correct; use it
|
671
|
+
if len(expected_output_cols_list) == len(output_df_columns_set):
|
672
|
+
return expected_output_cols_list
|
673
|
+
# otherwise, use the sklearn estimator's output
|
674
|
+
else:
|
675
|
+
return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
610
676
|
|
611
677
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
612
678
|
@telemetry.send_api_usage_telemetry(
|
@@ -638,24 +704,26 @@ class Perceptron(BaseTransformer):
|
|
638
704
|
# are specific to the type of dataset used.
|
639
705
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
640
706
|
|
707
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
708
|
+
|
641
709
|
if isinstance(dataset, DataFrame):
|
642
|
-
self.
|
643
|
-
|
644
|
-
|
645
|
-
|
646
|
-
|
710
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
711
|
+
self._deps = self._get_dependencies()
|
712
|
+
assert isinstance(
|
713
|
+
dataset._session, Session
|
714
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
647
715
|
transform_kwargs = dict(
|
648
716
|
session=dataset._session,
|
649
717
|
dependencies=self._deps,
|
650
|
-
drop_input_cols
|
718
|
+
drop_input_cols=self._drop_input_cols,
|
651
719
|
expected_output_cols_type="float",
|
652
720
|
)
|
721
|
+
expected_output_cols = self._align_expected_output_names(
|
722
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
723
|
+
)
|
653
724
|
|
654
725
|
elif isinstance(dataset, pd.DataFrame):
|
655
|
-
transform_kwargs = dict(
|
656
|
-
snowpark_input_cols = self._snowpark_cols,
|
657
|
-
drop_input_cols = self._drop_input_cols
|
658
|
-
)
|
726
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
659
727
|
|
660
728
|
transform_handlers = ModelTransformerBuilder.build(
|
661
729
|
dataset=dataset,
|
@@ -667,7 +735,7 @@ class Perceptron(BaseTransformer):
|
|
667
735
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
668
736
|
inference_method=inference_method,
|
669
737
|
input_cols=self.input_cols,
|
670
|
-
expected_output_cols=
|
738
|
+
expected_output_cols=expected_output_cols,
|
671
739
|
**transform_kwargs
|
672
740
|
)
|
673
741
|
return output_df
|
@@ -697,29 +765,30 @@ class Perceptron(BaseTransformer):
|
|
697
765
|
Output dataset with log probability of the sample for each class in the model.
|
698
766
|
"""
|
699
767
|
super()._check_dataset_type(dataset)
|
700
|
-
inference_method="predict_log_proba"
|
768
|
+
inference_method = "predict_log_proba"
|
769
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
701
770
|
|
702
771
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
703
772
|
# are specific to the type of dataset used.
|
704
773
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
705
774
|
|
706
775
|
if isinstance(dataset, DataFrame):
|
707
|
-
self.
|
708
|
-
|
709
|
-
|
710
|
-
|
711
|
-
|
776
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
777
|
+
self._deps = self._get_dependencies()
|
778
|
+
assert isinstance(
|
779
|
+
dataset._session, Session
|
780
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
712
781
|
transform_kwargs = dict(
|
713
782
|
session=dataset._session,
|
714
783
|
dependencies=self._deps,
|
715
|
-
drop_input_cols
|
784
|
+
drop_input_cols=self._drop_input_cols,
|
716
785
|
expected_output_cols_type="float",
|
717
786
|
)
|
787
|
+
expected_output_cols = self._align_expected_output_names(
|
788
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
789
|
+
)
|
718
790
|
elif isinstance(dataset, pd.DataFrame):
|
719
|
-
transform_kwargs = dict(
|
720
|
-
snowpark_input_cols = self._snowpark_cols,
|
721
|
-
drop_input_cols = self._drop_input_cols
|
722
|
-
)
|
791
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
723
792
|
|
724
793
|
transform_handlers = ModelTransformerBuilder.build(
|
725
794
|
dataset=dataset,
|
@@ -732,7 +801,7 @@ class Perceptron(BaseTransformer):
|
|
732
801
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
733
802
|
inference_method=inference_method,
|
734
803
|
input_cols=self.input_cols,
|
735
|
-
expected_output_cols=
|
804
|
+
expected_output_cols=expected_output_cols,
|
736
805
|
**transform_kwargs
|
737
806
|
)
|
738
807
|
return output_df
|
@@ -760,30 +829,32 @@ class Perceptron(BaseTransformer):
|
|
760
829
|
Output dataset with results of the decision function for the samples in input dataset.
|
761
830
|
"""
|
762
831
|
super()._check_dataset_type(dataset)
|
763
|
-
inference_method="decision_function"
|
832
|
+
inference_method = "decision_function"
|
764
833
|
|
765
834
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
766
835
|
# are specific to the type of dataset used.
|
767
836
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
768
837
|
|
838
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
839
|
+
|
769
840
|
if isinstance(dataset, DataFrame):
|
770
|
-
self.
|
771
|
-
|
772
|
-
|
773
|
-
|
774
|
-
|
841
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
842
|
+
self._deps = self._get_dependencies()
|
843
|
+
assert isinstance(
|
844
|
+
dataset._session, Session
|
845
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
775
846
|
transform_kwargs = dict(
|
776
847
|
session=dataset._session,
|
777
848
|
dependencies=self._deps,
|
778
|
-
drop_input_cols
|
849
|
+
drop_input_cols=self._drop_input_cols,
|
779
850
|
expected_output_cols_type="float",
|
780
851
|
)
|
852
|
+
expected_output_cols = self._align_expected_output_names(
|
853
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
854
|
+
)
|
781
855
|
|
782
856
|
elif isinstance(dataset, pd.DataFrame):
|
783
|
-
transform_kwargs = dict(
|
784
|
-
snowpark_input_cols = self._snowpark_cols,
|
785
|
-
drop_input_cols = self._drop_input_cols
|
786
|
-
)
|
857
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
787
858
|
|
788
859
|
transform_handlers = ModelTransformerBuilder.build(
|
789
860
|
dataset=dataset,
|
@@ -796,7 +867,7 @@ class Perceptron(BaseTransformer):
|
|
796
867
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
797
868
|
inference_method=inference_method,
|
798
869
|
input_cols=self.input_cols,
|
799
|
-
expected_output_cols=
|
870
|
+
expected_output_cols=expected_output_cols,
|
800
871
|
**transform_kwargs
|
801
872
|
)
|
802
873
|
return output_df
|
@@ -825,17 +896,17 @@ class Perceptron(BaseTransformer):
|
|
825
896
|
Output dataset with probability of the sample for each class in the model.
|
826
897
|
"""
|
827
898
|
super()._check_dataset_type(dataset)
|
828
|
-
inference_method="score_samples"
|
899
|
+
inference_method = "score_samples"
|
829
900
|
|
830
901
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
831
902
|
# are specific to the type of dataset used.
|
832
903
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
833
904
|
|
905
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
906
|
+
|
834
907
|
if isinstance(dataset, DataFrame):
|
835
|
-
self.
|
836
|
-
|
837
|
-
inference_method=inference_method,
|
838
|
-
)
|
908
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
909
|
+
self._deps = self._get_dependencies()
|
839
910
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
840
911
|
transform_kwargs = dict(
|
841
912
|
session=dataset._session,
|
@@ -843,6 +914,9 @@ class Perceptron(BaseTransformer):
|
|
843
914
|
drop_input_cols = self._drop_input_cols,
|
844
915
|
expected_output_cols_type="float",
|
845
916
|
)
|
917
|
+
expected_output_cols = self._align_expected_output_names(
|
918
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
919
|
+
)
|
846
920
|
|
847
921
|
elif isinstance(dataset, pd.DataFrame):
|
848
922
|
transform_kwargs = dict(
|
@@ -861,7 +935,7 @@ class Perceptron(BaseTransformer):
|
|
861
935
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
862
936
|
inference_method=inference_method,
|
863
937
|
input_cols=self.input_cols,
|
864
|
-
expected_output_cols=
|
938
|
+
expected_output_cols=expected_output_cols,
|
865
939
|
**transform_kwargs
|
866
940
|
)
|
867
941
|
return output_df
|
@@ -896,17 +970,15 @@ class Perceptron(BaseTransformer):
|
|
896
970
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
897
971
|
|
898
972
|
if isinstance(dataset, DataFrame):
|
899
|
-
self.
|
900
|
-
|
901
|
-
inference_method="score",
|
902
|
-
)
|
973
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
974
|
+
self._deps = self._get_dependencies()
|
903
975
|
selected_cols = self._get_active_columns()
|
904
976
|
if len(selected_cols) > 0:
|
905
977
|
dataset = dataset.select(selected_cols)
|
906
978
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
907
979
|
transform_kwargs = dict(
|
908
980
|
session=dataset._session,
|
909
|
-
dependencies=
|
981
|
+
dependencies=self._deps,
|
910
982
|
score_sproc_imports=['sklearn'],
|
911
983
|
)
|
912
984
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -971,11 +1043,8 @@ class Perceptron(BaseTransformer):
|
|
971
1043
|
|
972
1044
|
if isinstance(dataset, DataFrame):
|
973
1045
|
|
974
|
-
self.
|
975
|
-
|
976
|
-
inference_method=inference_method,
|
977
|
-
|
978
|
-
)
|
1046
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1047
|
+
self._deps = self._get_dependencies()
|
979
1048
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
980
1049
|
transform_kwargs = dict(
|
981
1050
|
session = dataset._session,
|
@@ -1008,50 +1077,84 @@ class Perceptron(BaseTransformer):
|
|
1008
1077
|
)
|
1009
1078
|
return output_df
|
1010
1079
|
|
1080
|
+
|
1081
|
+
|
1082
|
+
def to_sklearn(self) -> Any:
|
1083
|
+
"""Get sklearn.linear_model.Perceptron object.
|
1084
|
+
"""
|
1085
|
+
if self._sklearn_object is None:
|
1086
|
+
self._sklearn_object = self._create_sklearn_object()
|
1087
|
+
return self._sklearn_object
|
1088
|
+
|
1089
|
+
def to_xgboost(self) -> Any:
|
1090
|
+
raise exceptions.SnowflakeMLException(
|
1091
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1092
|
+
original_exception=AttributeError(
|
1093
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1094
|
+
"to_xgboost()",
|
1095
|
+
"to_sklearn()"
|
1096
|
+
)
|
1097
|
+
),
|
1098
|
+
)
|
1099
|
+
|
1100
|
+
def to_lightgbm(self) -> Any:
|
1101
|
+
raise exceptions.SnowflakeMLException(
|
1102
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1103
|
+
original_exception=AttributeError(
|
1104
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1105
|
+
"to_lightgbm()",
|
1106
|
+
"to_sklearn()"
|
1107
|
+
)
|
1108
|
+
),
|
1109
|
+
)
|
1110
|
+
|
1111
|
+
def _get_dependencies(self) -> List[str]:
|
1112
|
+
return self._deps
|
1113
|
+
|
1011
1114
|
|
1012
|
-
def
|
1115
|
+
def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
1013
1116
|
self._model_signature_dict = dict()
|
1014
1117
|
|
1015
1118
|
PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
|
1016
1119
|
|
1017
|
-
inputs = list(_infer_signature(dataset[self.input_cols], "input"))
|
1120
|
+
inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
|
1018
1121
|
outputs: List[BaseFeatureSpec] = []
|
1019
1122
|
if hasattr(self, "predict"):
|
1020
1123
|
# keep mypy happy
|
1021
|
-
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1124
|
+
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1022
1125
|
# For classifier, the type of predict is the same as the type of label
|
1023
|
-
if self._sklearn_object._estimator_type ==
|
1024
|
-
|
1126
|
+
if self._sklearn_object._estimator_type == "classifier":
|
1127
|
+
# label columns is the desired type for output
|
1025
1128
|
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
1026
1129
|
# rename the output columns
|
1027
1130
|
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
1028
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1029
|
-
|
1030
|
-
|
1131
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1132
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1133
|
+
)
|
1031
1134
|
# For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
|
1032
1135
|
# For outlier models, returns -1 for outliers and 1 for inliers.
|
1033
|
-
# Clusterer returns int64 cluster labels.
|
1136
|
+
# Clusterer returns int64 cluster labels.
|
1034
1137
|
elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
|
1035
1138
|
outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
|
1036
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1037
|
-
|
1038
|
-
|
1039
|
-
|
1139
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1140
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1141
|
+
)
|
1142
|
+
|
1040
1143
|
# For regressor, the type of predict is float64
|
1041
|
-
elif self._sklearn_object._estimator_type ==
|
1144
|
+
elif self._sklearn_object._estimator_type == "regressor":
|
1042
1145
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
1043
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1044
|
-
|
1045
|
-
|
1046
|
-
|
1146
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1147
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1148
|
+
)
|
1149
|
+
|
1047
1150
|
for prob_func in PROB_FUNCTIONS:
|
1048
1151
|
if hasattr(self, prob_func):
|
1049
1152
|
output_cols_prefix: str = f"{prob_func}_"
|
1050
1153
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
1051
1154
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
1052
|
-
self._model_signature_dict[prob_func] = ModelSignature(
|
1053
|
-
|
1054
|
-
|
1155
|
+
self._model_signature_dict[prob_func] = ModelSignature(
|
1156
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1157
|
+
)
|
1055
1158
|
|
1056
1159
|
# Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
|
1057
1160
|
items = list(self._model_signature_dict.items())
|
@@ -1064,10 +1167,10 @@ class Perceptron(BaseTransformer):
|
|
1064
1167
|
"""Returns model signature of current class.
|
1065
1168
|
|
1066
1169
|
Raises:
|
1067
|
-
|
1170
|
+
SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
|
1068
1171
|
|
1069
1172
|
Returns:
|
1070
|
-
Dict
|
1173
|
+
Dict with each method and its input output signature
|
1071
1174
|
"""
|
1072
1175
|
if self._model_signature_dict is None:
|
1073
1176
|
raise exceptions.SnowflakeMLException(
|
@@ -1075,35 +1178,3 @@ class Perceptron(BaseTransformer):
|
|
1075
1178
|
original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
|
1076
1179
|
)
|
1077
1180
|
return self._model_signature_dict
|
1078
|
-
|
1079
|
-
def to_sklearn(self) -> Any:
|
1080
|
-
"""Get sklearn.linear_model.Perceptron object.
|
1081
|
-
"""
|
1082
|
-
if self._sklearn_object is None:
|
1083
|
-
self._sklearn_object = self._create_sklearn_object()
|
1084
|
-
return self._sklearn_object
|
1085
|
-
|
1086
|
-
def to_xgboost(self) -> Any:
|
1087
|
-
raise exceptions.SnowflakeMLException(
|
1088
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1089
|
-
original_exception=AttributeError(
|
1090
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1091
|
-
"to_xgboost()",
|
1092
|
-
"to_sklearn()"
|
1093
|
-
)
|
1094
|
-
),
|
1095
|
-
)
|
1096
|
-
|
1097
|
-
def to_lightgbm(self) -> Any:
|
1098
|
-
raise exceptions.SnowflakeMLException(
|
1099
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1100
|
-
original_exception=AttributeError(
|
1101
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1102
|
-
"to_lightgbm()",
|
1103
|
-
"to_sklearn()"
|
1104
|
-
)
|
1105
|
-
),
|
1106
|
-
)
|
1107
|
-
|
1108
|
-
def _get_dependencies(self) -> List[str]:
|
1109
|
-
return self._deps
|