snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
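The new snowflake/ml/dataset modules above (dataset.py, dataset_factory.py, dataset_reader.py) introduce a Dataset abstraction in 1.5.0. A minimal sketch of how such an API is typically driven is shown below; the factory and reader names are assumptions inferred from the new file names, not confirmed by this diff:

    # Hypothetical sketch only: function/argument names below are assumptions based on
    # the new dataset_factory.py and dataset_reader.py files listed above.
    from snowflake.ml import dataset

    ds = dataset.create_from_dataframe(        # assumed factory helper in dataset_factory.py
        session, name="MY_DATASET", version="v1", input_dataframe=train_df
    )
    train_pdf = ds.read.to_pandas()            # assumed DatasetReader accessor from dataset_reader.py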
snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py

@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
     BatchInferenceKwargsTypedDict,
     ScoreKwargsTypedDict
 )
+from snowflake.ml.model._signatures import utils as model_signature_utils
+from snowflake.ml.model.model_signature import (
+    BaseFeatureSpec,
+    DataType,
+    FeatureSpec,
+    ModelSignature,
+    _infer_signature,
+    _rename_signature_with_snowflake_identifiers,
+)

 from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder

@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
     validate_sklearn_args,
 )

-from snowflake.ml.model.model_signature import (
-    DataType,
-    FeatureSpec,
-    ModelSignature,
-    _infer_signature,
-    _rename_signature_with_snowflake_identifiers,
-    BaseFeatureSpec,
-)
-from snowflake.ml.model._signatures import utils as model_signature_utils
-
 _PROJECT = "ModelDevelopment"
 # Derive subproject from module name by removing "sklearn"
 # and converting module name from underscore to CamelCase
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.mixture".replace("sklear

 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]

-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class BayesianGaussianMixture(BaseTransformer):
     r"""Variational Bayesian estimation of a Gaussian mixture
     For more details on this class, see [sklearn.mixture.BayesianGaussianMixture]
@@ -325,12 +318,7 @@ class BayesianGaussianMixture(BaseTransformer):
         )
         return selected_cols

-    @telemetry.send_api_usage_telemetry(
-        project=_PROJECT,
-        subproject=_SUBPROJECT,
-        custom_tags=dict([("autogen", True)]),
-    )
-    def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "BayesianGaussianMixture":
+    def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "BayesianGaussianMixture":
         """Estimate model parameters with the EM algorithm
         For more details on this function, see [sklearn.mixture.BayesianGaussianMixture.fit]
         (https://scikit-learn.org/stable/modules/generated/sklearn.mixture.BayesianGaussianMixture.html#sklearn.mixture.BayesianGaussianMixture.fit)
@@ -357,12 +345,14 @@ class BayesianGaussianMixture(BaseTransformer):

         self._snowpark_cols = dataset.select(self.input_cols).columns

-
+        # If we are already in a stored procedure, no need to kick off another one.
         if SNOWML_SPROC_ENV in os.environ:
             statement_params = telemetry.get_function_usage_statement_params(
                 project=_PROJECT,
                 subproject=_SUBPROJECT,
-                function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), BayesianGaussianMixture.__class__.__name__),
+                function_name=telemetry.get_statement_params_full_func_name(
+                    inspect.currentframe(), BayesianGaussianMixture.__class__.__name__
+                ),
                 api_calls=[Session.call],
                 custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
             )
@@ -383,27 +373,24 @@ class BayesianGaussianMixture(BaseTransformer):
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
-        self.
+        self._generate_model_signatures(dataset)
         return self

     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-            return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.

         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error

-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -421,9 +408,7 @@ class BayesianGaussianMixture(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+

     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -459,7 +444,9 @@ class BayesianGaussianMixture(BaseTransformer):
             # when it is classifier, infer the datatype from label columns
             if expected_type_inferred == "" and 'predict' in self.model_signatures:
                 # Batch inference takes a single expected output column type. Use the first columns type for now.
-                label_cols_signatures = [row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols]
+                label_cols_signatures = [
+                    row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
+                ]
                 if len(label_cols_signatures) == 0:
                     error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
                     raise exceptions.SnowflakeMLException(
@@ -467,25 +454,23 @@ class BayesianGaussianMixture(BaseTransformer):
                         original_exception=ValueError(error_str),
                     )

-                expected_type_inferred = convert_sp_to_sf_type(
-                    label_cols_signatures[0].as_snowpark_type()
-                )
+                expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())

-            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
-
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

             transform_kwargs = dict(
-                session = dataset._session,
-                dependencies = self._deps,
-                drop_input_cols = self._drop_input_cols,
-                expected_output_cols_type = expected_type_inferred,
+                session=dataset._session,
+                dependencies=self._deps,
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_type=expected_type_inferred,
             )

         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)

         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -525,7 +510,7 @@ class BayesianGaussianMixture(BaseTransformer):
             Transformed dataset.
         """
         super()._check_dataset_type(dataset)
-        inference_method="transform"
+        inference_method = "transform"

         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
@@ -555,24 +540,19 @@ class BayesianGaussianMixture(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])

-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()

             transform_kwargs = dict(
-                session = dataset._session,
-                dependencies = self._deps,
-                drop_input_cols = self._drop_input_cols,
-                expected_output_cols_type = expected_dtype,
+                session=dataset._session,
+                dependencies=self._deps,
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_type=expected_dtype,
             )

         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)

         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -591,7 +571,11 @@ class BayesianGaussianMixture(BaseTransformer):
         return output_df

     @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_predict_",) -> Union[DataFrame, pd.DataFrame]:
+    def fit_predict(
+        self,
+        dataset: Union[DataFrame, pd.DataFrame],
+        output_cols_prefix: str = "fit_predict_",
+    ) -> Union[DataFrame, pd.DataFrame]:
         """ Estimate model parameters using X and predict the labels for X
         For more details on this function, see [sklearn.mixture.BayesianGaussianMixture.fit_predict]
         (https://scikit-learn.org/stable/modules/generated/sklearn.mixture.BayesianGaussianMixture.html#sklearn.mixture.BayesianGaussianMixture.fit_predict)
@@ -618,22 +602,104 @@ class BayesianGaussianMixture(BaseTransformer):
         )
         output_result, fitted_estimator = model_trainer.train_fit_predict(
             drop_input_cols=self._drop_input_cols,
-            expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
+            expected_output_cols_list=(
+                self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
+            ),
         )
         self._sklearn_object = fitted_estimator
         self._is_fitted = True
         return output_result

+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
+
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.

-
-
-
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
+
+
+    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
+        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
+        """
+        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
+        # The following condition is introduced for kneighbors methods, and not used in other methods
+        if output_cols:
+            output_cols = [
+                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
+                for c in output_cols
+            ]
+        elif getattr(self._sklearn_object, "classes_", None) is None:
+            output_cols = [output_cols_prefix]
+        elif self._sklearn_object is not None:
+            classes = self._sklearn_object.classes_
+            if isinstance(classes, numpy.ndarray):
+                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
+            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
+                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
+                output_cols = []
+                for i, cl in enumerate(classes):
+                    # For binary classification, there is only one output column for each class
+                    # ndarray as the two classes are complementary.
+                    if len(cl) == 2:
+                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
+                    else:
+                        output_cols.extend([
+                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
+                        ])
+        else:
+            output_cols = []
+
+        # Make sure column names are valid snowflake identifiers.
+        assert output_cols is not None  # Make MyPy happy
+        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
+
+        return rv
+
+    def _align_expected_output_names(
+        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
+    ) -> List[str]:
+        # in case the inferred output column names dimension is different
+        # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
+        output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
+        output_df_columns = list(output_df_pd.columns)
+        output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
+        if self.sample_weight_col:
+            output_df_columns_set -= set(self.sample_weight_col)
+        # if the dimension of inferred output column names is correct; use it
+        if len(expected_output_cols_list) == len(output_df_columns_set):
+            return expected_output_cols_list
+        # otherwise, use the sklearn estimator's output
+        else:
+            return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))

     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -667,24 +733,26 @@ class BayesianGaussianMixture(BaseTransformer):
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()

+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
+
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
-            assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                drop_input_cols = self._drop_input_cols,
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )

         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)

         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -696,7 +764,7 @@ class BayesianGaussianMixture(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -728,29 +796,30 @@ class BayesianGaussianMixture(BaseTransformer):
             Output dataset with log probability of the sample for each class in the model.
         """
         super()._check_dataset_type(dataset)
-        inference_method="predict_log_proba"
+        inference_method = "predict_log_proba"
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)

         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
-            assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                drop_input_cols = self._drop_input_cols,
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)

         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -763,7 +832,7 @@ class BayesianGaussianMixture(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -789,30 +858,32 @@ class BayesianGaussianMixture(BaseTransformer):
             Output dataset with results of the decision function for the samples in input dataset.
         """
         super()._check_dataset_type(dataset)
-        inference_method="decision_function"
+        inference_method = "decision_function"

         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()

+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
+
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
-            assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                drop_input_cols = self._drop_input_cols,
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )

         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)

         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -825,7 +896,7 @@ class BayesianGaussianMixture(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -856,17 +927,17 @@ class BayesianGaussianMixture(BaseTransformer):
             Output dataset with probability of the sample for each class in the model.
         """
         super()._check_dataset_type(dataset)
-        inference_method="score_samples"
+        inference_method = "score_samples"

         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()

+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
+
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -874,6 +945,9 @@ class BayesianGaussianMixture(BaseTransformer):
                 drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )

         elif isinstance(dataset, pd.DataFrame):
             transform_kwargs = dict(
@@ -892,7 +966,7 @@ class BayesianGaussianMixture(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -927,17 +1001,15 @@ class BayesianGaussianMixture(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -1002,11 +1074,8 @@ class BayesianGaussianMixture(BaseTransformer):

         if isinstance(dataset, DataFrame):

-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
@@ -1039,50 +1108,84 @@ class BayesianGaussianMixture(BaseTransformer):
         )
         return output_df

+
+
+    def to_sklearn(self) -> Any:
+        """Get sklearn.mixture.BayesianGaussianMixture object.
+        """
+        if self._sklearn_object is None:
+            self._sklearn_object = self._create_sklearn_object()
+        return self._sklearn_object
+
+    def to_xgboost(self) -> Any:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.METHOD_NOT_ALLOWED,
+            original_exception=AttributeError(
+                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
+                    "to_xgboost()",
+                    "to_sklearn()"
+                )
+            ),
+        )

-    def 
+    def to_lightgbm(self) -> Any:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.METHOD_NOT_ALLOWED,
+            original_exception=AttributeError(
+                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
+                    "to_lightgbm()",
+                    "to_sklearn()"
+                )
+            ),
+        )
+
+    def _get_dependencies(self) -> List[str]:
+        return self._deps
+
+
+    def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         self._model_signature_dict = dict()

         PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]

-        inputs = list(_infer_signature(dataset[self.input_cols], "input"))
+        inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
         outputs: List[BaseFeatureSpec] = []
         if hasattr(self, "predict"):
             # keep mypy happy
-            assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
+            assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
             # For classifier, the type of predict is the same as the type of label
-            if self._sklearn_object._estimator_type == "classifier":
-
+            if self._sklearn_object._estimator_type == "classifier":
+                # label columns is the desired type for output
                 outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
                 # rename the output columns
                 outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
-                self._model_signature_dict["predict"] = ModelSignature(
-
-
+                self._model_signature_dict["predict"] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
             # For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
             # For outlier models, returns -1 for outliers and 1 for inliers.
-            # Clusterer returns int64 cluster labels.
+            # Clusterer returns int64 cluster labels.
             elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
                 outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
-                self._model_signature_dict["predict"] = ModelSignature(
-
-
-
+                self._model_signature_dict["predict"] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
+
             # For regressor, the type of predict is float64
-            elif self._sklearn_object._estimator_type == "regressor":
+            elif self._sklearn_object._estimator_type == "regressor":
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-                self._model_signature_dict["predict"] = ModelSignature(
-
-
-
+                self._model_signature_dict["predict"] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
+
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(
-
-
+                self._model_signature_dict[prob_func] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )

         # Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
         items = list(self._model_signature_dict.items())
@@ -1095,10 +1198,10 @@ class BayesianGaussianMixture(BaseTransformer):
         """Returns model signature of current class.

         Raises:
-
+            SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred

         Returns:
-            Dict
+            Dict with each method and its input output signature
         """
         if self._model_signature_dict is None:
             raise exceptions.SnowflakeMLException(
@@ -1106,35 +1209,3 @@ class BayesianGaussianMixture(BaseTransformer):
                 original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
             )
         return self._model_signature_dict
-
-    def to_sklearn(self) -> Any:
-        """Get sklearn.mixture.BayesianGaussianMixture object.
-        """
-        if self._sklearn_object is None:
-            self._sklearn_object = self._create_sklearn_object()
-        return self._sklearn_object
-
-    def to_xgboost(self) -> Any:
-        raise exceptions.SnowflakeMLException(
-            error_code=error_codes.METHOD_NOT_ALLOWED,
-            original_exception=AttributeError(
-                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
-                    "to_xgboost()",
-                    "to_sklearn()"
-                )
-            ),
-        )
-
-    def to_lightgbm(self) -> Any:
-        raise exceptions.SnowflakeMLException(
-            error_code=error_codes.METHOD_NOT_ALLOWED,
-            original_exception=AttributeError(
-                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
-                    "to_lightgbm()",
-                    "to_sklearn()"
-                )
-            ),
-        )
-
-    def _get_dependencies(self) -> List[str]:
-        return self._deps
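Taken together, these hunks replace the per-estimator fit() (and its telemetry decorator) with _fit(), generate signatures via _generate_model_signatures(), and move the to_sklearn()/to_xgboost()/to_lightgbm() accessors and _get_dependencies() ahead of the model_signatures property. A minimal usage sketch of the resulting surface is shown below; the training DataFrame and column names are placeholders, and the public fit() wrapper is assumed to live in the changed framework/base.py rather than in this file:

    # Sketch of the user-facing flow implied by the refactor above (placeholder data/columns).
    from snowflake.ml.modeling.mixture import BayesianGaussianMixture

    bgm = BayesianGaussianMixture(n_components=3, input_cols=["F1", "F2"], output_cols=["CLUSTER"])
    bgm.fit(train_df)                        # public fit() assumed to delegate to the new _fit()
    print(bgm.model_signatures["predict"])   # built by _generate_model_signatures() during fit
    skl = bgm.to_sklearn()                   # underlying sklearn.mixture.BayesianGaussianMixture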