snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
     BatchInferenceKwargsTypedDict,
     ScoreKwargsTypedDict
 )
+from snowflake.ml.model._signatures import utils as model_signature_utils
+from snowflake.ml.model.model_signature import (
+    BaseFeatureSpec,
+    DataType,
+    FeatureSpec,
+    ModelSignature,
+    _infer_signature,
+    _rename_signature_with_snowflake_identifiers,
+)
 
 from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
 
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
     validate_sklearn_args,
 )
 
-from snowflake.ml.model.model_signature import (
-    DataType,
-    FeatureSpec,
-    ModelSignature,
-    _infer_signature,
-    _rename_signature_with_snowflake_identifiers,
-    BaseFeatureSpec,
-)
-from snowflake.ml.model._signatures import utils as model_signature_utils
-
 _PROJECT = "ModelDevelopment"
 # Derive subproject from module name by removing "sklearn"
 # and converting module name from underscore to CamelCase
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class PCA(BaseTransformer):
     r"""Principal component analysis (PCA)
     For more details on this class, see [sklearn.decomposition.PCA]
@@ -286,12 +279,7 @@ class PCA(BaseTransformer):
         )
         return selected_cols
 
-    @telemetry.send_api_usage_telemetry(
-        project=_PROJECT,
-        subproject=_SUBPROJECT,
-        custom_tags=dict([("autogen", True)]),
-    )
-    def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "PCA":
+    def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "PCA":
         """Fit the model with X
         For more details on this function, see [sklearn.decomposition.PCA.fit]
         (https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html#sklearn.decomposition.PCA.fit)
@@ -318,12 +306,14 @@ class PCA(BaseTransformer):
 
         self._snowpark_cols = dataset.select(self.input_cols).columns
 
-
+        # If we are already in a stored procedure, no need to kick off another one.
         if SNOWML_SPROC_ENV in os.environ:
             statement_params = telemetry.get_function_usage_statement_params(
                 project=_PROJECT,
                 subproject=_SUBPROJECT,
-                function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), PCA.__class__.__name__),
+                function_name=telemetry.get_statement_params_full_func_name(
+                    inspect.currentframe(), PCA.__class__.__name__
+                ),
                 api_calls=[Session.call],
                 custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
             )
@@ -344,27 +334,24 @@ class PCA(BaseTransformer):
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
-        self._get_model_signatures(dataset)
+        self._generate_model_signatures(dataset)
         return self
 
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -382,9 +369,7 @@ class PCA(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -418,7 +403,9 @@ class PCA(BaseTransformer):
             # when it is classifier, infer the datatype from label columns
             if expected_type_inferred == "" and 'predict' in self.model_signatures:
                 # Batch inference takes a single expected output column type. Use the first columns type for now.
-                label_cols_signatures = [row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols]
+                label_cols_signatures = [
+                    row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
+                ]
                 if len(label_cols_signatures) == 0:
                     error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
                     raise exceptions.SnowflakeMLException(
@@ -426,25 +413,23 @@ class PCA(BaseTransformer):
                         original_exception=ValueError(error_str),
                     )
 
-                expected_type_inferred = convert_sp_to_sf_type(
-                    label_cols_signatures[0].as_snowpark_type()
-                )
+                expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
-
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
-                session = dataset._session,
-                dependencies = self._deps,
-                drop_input_cols = self._drop_input_cols,
-                expected_output_cols_type = expected_type_inferred,
+                session=dataset._session,
+                dependencies=self._deps,
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_type=expected_type_inferred,
             )
 
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -486,7 +471,7 @@ class PCA(BaseTransformer):
             Transformed dataset.
         """
         super()._check_dataset_type(dataset)
-        inference_method="transform"
+        inference_method = "transform"
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
@@ -516,24 +501,19 @@ class PCA(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
-                session = dataset._session,
-                dependencies = self._deps,
-                drop_input_cols = self._drop_input_cols,
-                expected_output_cols_type = expected_dtype,
+                session=dataset._session,
+                dependencies=self._deps,
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_type=expected_dtype,
             )
 
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -552,7 +532,11 @@ class PCA(BaseTransformer):
         return output_df
 
     @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_predict_",) -> Union[DataFrame, pd.DataFrame]:
+    def fit_predict(
+        self,
+        dataset: Union[DataFrame, pd.DataFrame],
+        output_cols_prefix: str = "fit_predict_",
+    ) -> Union[DataFrame, pd.DataFrame]:
         """ Method not supported for this class.
 
 
@@ -577,22 +561,106 @@ class PCA(BaseTransformer):
         )
         output_result, fitted_estimator = model_trainer.train_fit_predict(
             drop_input_cols=self._drop_input_cols,
-            expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
+            expected_output_cols_list=(
+                self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
+            ),
         )
         self._sklearn_object = fitted_estimator
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Fit the model with X and apply the dimensionality reduction on X
+        For more details on this function, see [sklearn.decomposition.PCA.fit_transform]
+        (https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html#sklearn.decomposition.PCA.fit_transform)
+
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
 
-
-
-
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
+
+
+    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
+        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
+        """
+        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
+        # The following condition is introduced for kneighbors methods, and not used in other methods
+        if output_cols:
+            output_cols = [
+                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
+                for c in output_cols
+            ]
+        elif getattr(self._sklearn_object, "classes_", None) is None:
+            output_cols = [output_cols_prefix]
+        elif self._sklearn_object is not None:
+            classes = self._sklearn_object.classes_
+            if isinstance(classes, numpy.ndarray):
+                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
+            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
+                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
+                output_cols = []
+                for i, cl in enumerate(classes):
+                    # For binary classification, there is only one output column for each class
+                    # ndarray as the two classes are complementary.
+                    if len(cl) == 2:
+                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
+                    else:
+                        output_cols.extend([
+                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
+                        ])
+        else:
+            output_cols = []
+
+        # Make sure column names are valid snowflake identifiers.
+        assert output_cols is not None  # Make MyPy happy
+        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
+
+        return rv
+
+    def _align_expected_output_names(
+        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
+    ) -> List[str]:
+        # in case the inferred output column names dimension is different
+        # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
+        output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
+        output_df_columns = list(output_df_pd.columns)
+        output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
+        if self.sample_weight_col:
+            output_df_columns_set -= set(self.sample_weight_col)
+        # if the dimension of inferred output column names is correct; use it
+        if len(expected_output_cols_list) == len(output_df_columns_set):
+            return expected_output_cols_list
+        # otherwise, use the sklearn estimator's output
+        else:
+            return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
 
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -624,24 +692,26 @@ class PCA(BaseTransformer):
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
+
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
-            assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                drop_input_cols = self._drop_input_cols,
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
 
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -653,7 +723,7 @@ class PCA(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=self._get_output_column_names(output_cols_prefix),
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -683,29 +753,30 @@ class PCA(BaseTransformer):
             Output dataset with log probability of the sample for each class in the model.
         """
         super()._check_dataset_type(dataset)
-        inference_method="predict_log_proba"
+        inference_method = "predict_log_proba"
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
-            assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                drop_input_cols = self._drop_input_cols,
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
|
|
718
789
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
719
790
|
inference_method=inference_method,
|
720
791
|
input_cols=self.input_cols,
|
721
|
-
expected_output_cols=
|
792
|
+
expected_output_cols=expected_output_cols,
|
722
793
|
**transform_kwargs
|
723
794
|
)
|
724
795
|
return output_df
|
@@ -744,30 +815,32 @@ class PCA(BaseTransformer):
             Output dataset with results of the decision function for the samples in input dataset.
         """
         super()._check_dataset_type(dataset)
-        inference_method="decision_function"
+        inference_method = "decision_function"
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
+
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
-            assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                drop_input_cols = self._drop_input_cols,
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
 
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -780,7 +853,7 @@ class PCA(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=self._get_output_column_names(output_cols_prefix),
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -811,17 +884,17 @@ class PCA(BaseTransformer):
             Output dataset with probability of the sample for each class in the model.
         """
         super()._check_dataset_type(dataset)
-        inference_method="score_samples"
+        inference_method = "score_samples"
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
+
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -829,6 +902,9 @@ class PCA(BaseTransformer):
                 drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
 
         elif isinstance(dataset, pd.DataFrame):
             transform_kwargs = dict(
@@ -847,7 +923,7 @@ class PCA(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=self._get_output_column_names(output_cols_prefix),
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -882,17 +958,15 @@ class PCA(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._deps,
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -957,11 +1031,8 @@ class PCA(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
@@ -994,50 +1065,84 @@ class PCA(BaseTransformer):
             )
         return output_df
 
+
+    def to_sklearn(self) -> Any:
+        """Get sklearn.decomposition.PCA object.
+        """
+        if self._sklearn_object is None:
+            self._sklearn_object = self._create_sklearn_object()
+        return self._sklearn_object
+
+    def to_xgboost(self) -> Any:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.METHOD_NOT_ALLOWED,
+            original_exception=AttributeError(
+                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
+                    "to_xgboost()",
+                    "to_sklearn()"
+                )
+            ),
+        )
 
-    def _get_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
+    def to_lightgbm(self) -> Any:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.METHOD_NOT_ALLOWED,
+            original_exception=AttributeError(
+                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
+                    "to_lightgbm()",
+                    "to_sklearn()"
+                )
+            ),
+        )
+
+    def _get_dependencies(self) -> List[str]:
+        return self._deps
+
+
+    def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         self._model_signature_dict = dict()
 
         PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
 
-        inputs = list(_infer_signature(dataset[self.input_cols], "input"))
+        inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
         outputs: List[BaseFeatureSpec] = []
         if hasattr(self, "predict"):
             # keep mypy happy
-            assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
+            assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
             # For classifier, the type of predict is the same as the type of label
-            if self._sklearn_object._estimator_type == 'classifier':
-                # label columns is the desired type for output
+            if self._sklearn_object._estimator_type == "classifier":
+                # label columns is the desired type for output
                 outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
                 # rename the output columns
                 outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
-                self._model_signature_dict["predict"] = ModelSignature(inputs,
-                                                                       ([] if self._drop_input_cols else inputs)
-                                                                       + outputs)
+                self._model_signature_dict["predict"] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
             # For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
             # For outlier models, returns -1 for outliers and 1 for inliers.
-            # Clusterer returns int64 cluster labels.
+            # Clusterer returns int64 cluster labels.
             elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
                 outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
-                self._model_signature_dict["predict"] = ModelSignature(inputs,
-                                                                       ([] if self._drop_input_cols else inputs)
-                                                                       + outputs)
-
+                self._model_signature_dict["predict"] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
+
             # For regressor, the type of predict is float64
-            elif self._sklearn_object._estimator_type == 'regressor':
+            elif self._sklearn_object._estimator_type == "regressor":
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-                self._model_signature_dict["predict"] = ModelSignature(inputs,
-                                                                       ([] if self._drop_input_cols else inputs)
-                                                                       + outputs)
-
+                self._model_signature_dict["predict"] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
+
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs,
-                                                                       ([] if self._drop_input_cols else inputs)
-                                                                       + outputs)
+                self._model_signature_dict[prob_func] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
 
         # Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
         items = list(self._model_signature_dict.items())
@@ -1050,10 +1155,10 @@ class PCA(BaseTransformer):
         """Returns model signature of current class.
 
         Raises:
-            exceptions.SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
+            SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
 
         Returns:
-            Dict[str, ModelSignature]: each method and its input output signature
+            Dict with each method and its input output signature
         """
         if self._model_signature_dict is None:
             raise exceptions.SnowflakeMLException(
@@ -1061,35 +1166,3 @@ class PCA(BaseTransformer):
             original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
         )
         return self._model_signature_dict
-
-    def to_sklearn(self) -> Any:
-        """Get sklearn.decomposition.PCA object.
-        """
-        if self._sklearn_object is None:
-            self._sklearn_object = self._create_sklearn_object()
-        return self._sklearn_object
-
-    def to_xgboost(self) -> Any:
-        raise exceptions.SnowflakeMLException(
-            error_code=error_codes.METHOD_NOT_ALLOWED,
-            original_exception=AttributeError(
-                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
-                    "to_xgboost()",
-                    "to_sklearn()"
-                )
-            ),
-        )
-
-    def to_lightgbm(self) -> Any:
-        raise exceptions.SnowflakeMLException(
-            error_code=error_codes.METHOD_NOT_ALLOWED,
-            original_exception=AttributeError(
-                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
-                    "to_lightgbm()",
-                    "to_sklearn()"
-                )
-            ),
-        )
-
-    def _get_dependencies(self) -> List[str]:
-        return self._deps