snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
|
|
33
33
|
BatchInferenceKwargsTypedDict,
|
34
34
|
ScoreKwargsTypedDict
|
35
35
|
)
|
36
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
37
|
+
from snowflake.ml.model.model_signature import (
|
38
|
+
BaseFeatureSpec,
|
39
|
+
DataType,
|
40
|
+
FeatureSpec,
|
41
|
+
ModelSignature,
|
42
|
+
_infer_signature,
|
43
|
+
_rename_signature_with_snowflake_identifiers,
|
44
|
+
)
|
36
45
|
|
37
46
|
from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
|
38
47
|
|
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
|
|
43
52
|
validate_sklearn_args,
|
44
53
|
)
|
45
54
|
|
46
|
-
from snowflake.ml.model.model_signature import (
|
47
|
-
DataType,
|
48
|
-
FeatureSpec,
|
49
|
-
ModelSignature,
|
50
|
-
_infer_signature,
|
51
|
-
_rename_signature_with_snowflake_identifiers,
|
52
|
-
BaseFeatureSpec,
|
53
|
-
)
|
54
|
-
from snowflake.ml.model._signatures import utils as model_signature_utils
|
55
|
-
|
56
55
|
_PROJECT = "ModelDevelopment"
|
57
56
|
# Derive subproject from module name by removing "sklearn"
|
58
57
|
# and converting module name from underscore to CamelCase
|
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.naive_bayes".replace("sk
|
|
61
60
|
|
62
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
63
62
|
|
64
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
65
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
66
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
67
|
-
return check
|
68
|
-
|
69
|
-
|
70
63
|
class GaussianNB(BaseTransformer):
|
71
64
|
r"""Gaussian Naive Bayes (GaussianNB)
|
72
65
|
For more details on this class, see [sklearn.naive_bayes.GaussianNB]
|
@@ -203,12 +196,7 @@ class GaussianNB(BaseTransformer):
|
|
203
196
|
)
|
204
197
|
return selected_cols
|
205
198
|
|
206
|
-
|
207
|
-
project=_PROJECT,
|
208
|
-
subproject=_SUBPROJECT,
|
209
|
-
custom_tags=dict([("autogen", True)]),
|
210
|
-
)
|
211
|
-
def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "GaussianNB":
|
199
|
+
def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "GaussianNB":
|
212
200
|
"""Fit Gaussian Naive Bayes according to X, y
|
213
201
|
For more details on this function, see [sklearn.naive_bayes.GaussianNB.fit]
|
214
202
|
(https://scikit-learn.org/stable/modules/generated/sklearn.naive_bayes.GaussianNB.html#sklearn.naive_bayes.GaussianNB.fit)
|
@@ -235,12 +223,14 @@ class GaussianNB(BaseTransformer):
|
|
235
223
|
|
236
224
|
self._snowpark_cols = dataset.select(self.input_cols).columns
|
237
225
|
|
238
|
-
|
226
|
+
# If we are already in a stored procedure, no need to kick off another one.
|
239
227
|
if SNOWML_SPROC_ENV in os.environ:
|
240
228
|
statement_params = telemetry.get_function_usage_statement_params(
|
241
229
|
project=_PROJECT,
|
242
230
|
subproject=_SUBPROJECT,
|
243
|
-
function_name=telemetry.get_statement_params_full_func_name(
|
231
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
232
|
+
inspect.currentframe(), GaussianNB.__class__.__name__
|
233
|
+
),
|
244
234
|
api_calls=[Session.call],
|
245
235
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
246
236
|
)
|
@@ -261,27 +251,24 @@ class GaussianNB(BaseTransformer):
|
|
261
251
|
)
|
262
252
|
self._sklearn_object = model_trainer.train()
|
263
253
|
self._is_fitted = True
|
264
|
-
self.
|
254
|
+
self._generate_model_signatures(dataset)
|
265
255
|
return self
|
266
256
|
|
267
257
|
def _batch_inference_validate_snowpark(
|
268
258
|
self,
|
269
259
|
dataset: DataFrame,
|
270
260
|
inference_method: str,
|
271
|
-
) ->
|
272
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
273
|
-
return the available package that exists in the snowflake anaconda channel
|
261
|
+
) -> None:
|
262
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
274
263
|
|
275
264
|
Args:
|
276
265
|
dataset: snowpark dataframe
|
277
266
|
inference_method: the inference method such as predict, score...
|
278
|
-
|
267
|
+
|
279
268
|
Raises:
|
280
269
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
281
270
|
SnowflakeMLException: If the session is None, raise error
|
282
271
|
|
283
|
-
Returns:
|
284
|
-
A list of available package that exists in the snowflake anaconda channel
|
285
272
|
"""
|
286
273
|
if not self._is_fitted:
|
287
274
|
raise exceptions.SnowflakeMLException(
|
@@ -299,9 +286,7 @@ class GaussianNB(BaseTransformer):
|
|
299
286
|
"Session must not specified for snowpark dataset."
|
300
287
|
),
|
301
288
|
)
|
302
|
-
|
303
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
304
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
289
|
+
|
305
290
|
|
306
291
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
307
292
|
@telemetry.send_api_usage_telemetry(
|
@@ -337,7 +322,9 @@ class GaussianNB(BaseTransformer):
|
|
337
322
|
# when it is classifier, infer the datatype from label columns
|
338
323
|
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
339
324
|
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
340
|
-
label_cols_signatures = [
|
325
|
+
label_cols_signatures = [
|
326
|
+
row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
|
327
|
+
]
|
341
328
|
if len(label_cols_signatures) == 0:
|
342
329
|
error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
|
343
330
|
raise exceptions.SnowflakeMLException(
|
@@ -345,25 +332,23 @@ class GaussianNB(BaseTransformer):
|
|
345
332
|
original_exception=ValueError(error_str),
|
346
333
|
)
|
347
334
|
|
348
|
-
expected_type_inferred = convert_sp_to_sf_type(
|
349
|
-
label_cols_signatures[0].as_snowpark_type()
|
350
|
-
)
|
335
|
+
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
351
336
|
|
352
|
-
self.
|
353
|
-
|
337
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
338
|
+
self._deps = self._get_dependencies()
|
339
|
+
assert isinstance(
|
340
|
+
dataset._session, Session
|
341
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
354
342
|
|
355
343
|
transform_kwargs = dict(
|
356
|
-
session
|
357
|
-
dependencies
|
358
|
-
drop_input_cols
|
359
|
-
expected_output_cols_type
|
344
|
+
session=dataset._session,
|
345
|
+
dependencies=self._deps,
|
346
|
+
drop_input_cols=self._drop_input_cols,
|
347
|
+
expected_output_cols_type=expected_type_inferred,
|
360
348
|
)
|
361
349
|
|
362
350
|
elif isinstance(dataset, pd.DataFrame):
|
363
|
-
transform_kwargs = dict(
|
364
|
-
snowpark_input_cols = self._snowpark_cols,
|
365
|
-
drop_input_cols = self._drop_input_cols
|
366
|
-
)
|
351
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
367
352
|
|
368
353
|
transform_handlers = ModelTransformerBuilder.build(
|
369
354
|
dataset=dataset,
|
@@ -403,7 +388,7 @@ class GaussianNB(BaseTransformer):
|
|
403
388
|
Transformed dataset.
|
404
389
|
"""
|
405
390
|
super()._check_dataset_type(dataset)
|
406
|
-
inference_method="transform"
|
391
|
+
inference_method = "transform"
|
407
392
|
|
408
393
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
409
394
|
# are specific to the type of dataset used.
|
@@ -433,24 +418,19 @@ class GaussianNB(BaseTransformer):
|
|
433
418
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
434
419
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
435
420
|
|
436
|
-
self.
|
437
|
-
|
438
|
-
inference_method=inference_method,
|
439
|
-
)
|
421
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
422
|
+
self._deps = self._get_dependencies()
|
440
423
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
441
424
|
|
442
425
|
transform_kwargs = dict(
|
443
|
-
session
|
444
|
-
dependencies
|
445
|
-
drop_input_cols
|
446
|
-
expected_output_cols_type
|
426
|
+
session=dataset._session,
|
427
|
+
dependencies=self._deps,
|
428
|
+
drop_input_cols=self._drop_input_cols,
|
429
|
+
expected_output_cols_type=expected_dtype,
|
447
430
|
)
|
448
431
|
|
449
432
|
elif isinstance(dataset, pd.DataFrame):
|
450
|
-
transform_kwargs = dict(
|
451
|
-
snowpark_input_cols = self._snowpark_cols,
|
452
|
-
drop_input_cols = self._drop_input_cols
|
453
|
-
)
|
433
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
454
434
|
|
455
435
|
transform_handlers = ModelTransformerBuilder.build(
|
456
436
|
dataset=dataset,
|
@@ -469,7 +449,11 @@ class GaussianNB(BaseTransformer):
|
|
469
449
|
return output_df
|
470
450
|
|
471
451
|
@available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
|
472
|
-
def fit_predict(
|
452
|
+
def fit_predict(
|
453
|
+
self,
|
454
|
+
dataset: Union[DataFrame, pd.DataFrame],
|
455
|
+
output_cols_prefix: str = "fit_predict_",
|
456
|
+
) -> Union[DataFrame, pd.DataFrame]:
|
473
457
|
""" Method not supported for this class.
|
474
458
|
|
475
459
|
|
@@ -494,22 +478,104 @@ class GaussianNB(BaseTransformer):
|
|
494
478
|
)
|
495
479
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
496
480
|
drop_input_cols=self._drop_input_cols,
|
497
|
-
expected_output_cols_list=
|
481
|
+
expected_output_cols_list=(
|
482
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
483
|
+
),
|
498
484
|
)
|
499
485
|
self._sklearn_object = fitted_estimator
|
500
486
|
self._is_fitted = True
|
501
487
|
return output_result
|
502
488
|
|
489
|
+
|
490
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
491
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
492
|
+
""" Method not supported for this class.
|
493
|
+
|
503
494
|
|
504
|
-
|
505
|
-
|
506
|
-
|
495
|
+
Raises:
|
496
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
497
|
+
|
498
|
+
Args:
|
499
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
500
|
+
Snowpark or Pandas DataFrame.
|
501
|
+
output_cols_prefix: Prefix for the response columns
|
507
502
|
Returns:
|
508
503
|
Transformed dataset.
|
509
504
|
"""
|
510
|
-
self.
|
511
|
-
|
512
|
-
|
505
|
+
self._infer_input_output_cols(dataset)
|
506
|
+
super()._check_dataset_type(dataset)
|
507
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
508
|
+
estimator=self._sklearn_object,
|
509
|
+
dataset=dataset,
|
510
|
+
input_cols=self.input_cols,
|
511
|
+
label_cols=self.label_cols,
|
512
|
+
sample_weight_col=self.sample_weight_col,
|
513
|
+
autogenerated=self._autogenerated,
|
514
|
+
subproject=_SUBPROJECT,
|
515
|
+
)
|
516
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
517
|
+
drop_input_cols=self._drop_input_cols,
|
518
|
+
expected_output_cols_list=self.output_cols,
|
519
|
+
)
|
520
|
+
self._sklearn_object = fitted_estimator
|
521
|
+
self._is_fitted = True
|
522
|
+
return output_result
|
523
|
+
|
524
|
+
|
525
|
+
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
526
|
+
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
527
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
528
|
+
"""
|
529
|
+
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
530
|
+
# The following condition is introduced for kneighbors methods, and not used in other methods
|
531
|
+
if output_cols:
|
532
|
+
output_cols = [
|
533
|
+
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
534
|
+
for c in output_cols
|
535
|
+
]
|
536
|
+
elif getattr(self._sklearn_object, "classes_", None) is None:
|
537
|
+
output_cols = [output_cols_prefix]
|
538
|
+
elif self._sklearn_object is not None:
|
539
|
+
classes = self._sklearn_object.classes_
|
540
|
+
if isinstance(classes, numpy.ndarray):
|
541
|
+
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
542
|
+
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
543
|
+
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
544
|
+
output_cols = []
|
545
|
+
for i, cl in enumerate(classes):
|
546
|
+
# For binary classification, there is only one output column for each class
|
547
|
+
# ndarray as the two classes are complementary.
|
548
|
+
if len(cl) == 2:
|
549
|
+
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
550
|
+
else:
|
551
|
+
output_cols.extend([
|
552
|
+
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
553
|
+
])
|
554
|
+
else:
|
555
|
+
output_cols = []
|
556
|
+
|
557
|
+
# Make sure column names are valid snowflake identifiers.
|
558
|
+
assert output_cols is not None # Make MyPy happy
|
559
|
+
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
560
|
+
|
561
|
+
return rv
|
562
|
+
|
563
|
+
def _align_expected_output_names(
|
564
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
565
|
+
) -> List[str]:
|
566
|
+
# in case the inferred output column names dimension is different
|
567
|
+
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
568
|
+
output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
|
569
|
+
output_df_columns = list(output_df_pd.columns)
|
570
|
+
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
571
|
+
if self.sample_weight_col:
|
572
|
+
output_df_columns_set -= set(self.sample_weight_col)
|
573
|
+
# if the dimension of inferred output column names is correct; use it
|
574
|
+
if len(expected_output_cols_list) == len(output_df_columns_set):
|
575
|
+
return expected_output_cols_list
|
576
|
+
# otherwise, use the sklearn estimator's output
|
577
|
+
else:
|
578
|
+
return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
513
579
|
|
514
580
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
515
581
|
@telemetry.send_api_usage_telemetry(
|
@@ -543,24 +609,26 @@ class GaussianNB(BaseTransformer):
|
|
543
609
|
# are specific to the type of dataset used.
|
544
610
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
545
611
|
|
612
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
613
|
+
|
546
614
|
if isinstance(dataset, DataFrame):
|
547
|
-
self.
|
548
|
-
|
549
|
-
|
550
|
-
|
551
|
-
|
615
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
616
|
+
self._deps = self._get_dependencies()
|
617
|
+
assert isinstance(
|
618
|
+
dataset._session, Session
|
619
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
552
620
|
transform_kwargs = dict(
|
553
621
|
session=dataset._session,
|
554
622
|
dependencies=self._deps,
|
555
|
-
drop_input_cols
|
623
|
+
drop_input_cols=self._drop_input_cols,
|
556
624
|
expected_output_cols_type="float",
|
557
625
|
)
|
626
|
+
expected_output_cols = self._align_expected_output_names(
|
627
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
628
|
+
)
|
558
629
|
|
559
630
|
elif isinstance(dataset, pd.DataFrame):
|
560
|
-
transform_kwargs = dict(
|
561
|
-
snowpark_input_cols = self._snowpark_cols,
|
562
|
-
drop_input_cols = self._drop_input_cols
|
563
|
-
)
|
631
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
564
632
|
|
565
633
|
transform_handlers = ModelTransformerBuilder.build(
|
566
634
|
dataset=dataset,
|
@@ -572,7 +640,7 @@ class GaussianNB(BaseTransformer):
|
|
572
640
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
573
641
|
inference_method=inference_method,
|
574
642
|
input_cols=self.input_cols,
|
575
|
-
expected_output_cols=
|
643
|
+
expected_output_cols=expected_output_cols,
|
576
644
|
**transform_kwargs
|
577
645
|
)
|
578
646
|
return output_df
|
@@ -604,29 +672,30 @@ class GaussianNB(BaseTransformer):
|
|
604
672
|
Output dataset with log probability of the sample for each class in the model.
|
605
673
|
"""
|
606
674
|
super()._check_dataset_type(dataset)
|
607
|
-
inference_method="predict_log_proba"
|
675
|
+
inference_method = "predict_log_proba"
|
676
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
608
677
|
|
609
678
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
610
679
|
# are specific to the type of dataset used.
|
611
680
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
612
681
|
|
613
682
|
if isinstance(dataset, DataFrame):
|
614
|
-
self.
|
615
|
-
|
616
|
-
|
617
|
-
|
618
|
-
|
683
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
684
|
+
self._deps = self._get_dependencies()
|
685
|
+
assert isinstance(
|
686
|
+
dataset._session, Session
|
687
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
619
688
|
transform_kwargs = dict(
|
620
689
|
session=dataset._session,
|
621
690
|
dependencies=self._deps,
|
622
|
-
drop_input_cols
|
691
|
+
drop_input_cols=self._drop_input_cols,
|
623
692
|
expected_output_cols_type="float",
|
624
693
|
)
|
694
|
+
expected_output_cols = self._align_expected_output_names(
|
695
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
696
|
+
)
|
625
697
|
elif isinstance(dataset, pd.DataFrame):
|
626
|
-
transform_kwargs = dict(
|
627
|
-
snowpark_input_cols = self._snowpark_cols,
|
628
|
-
drop_input_cols = self._drop_input_cols
|
629
|
-
)
|
698
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
630
699
|
|
631
700
|
transform_handlers = ModelTransformerBuilder.build(
|
632
701
|
dataset=dataset,
|
@@ -639,7 +708,7 @@ class GaussianNB(BaseTransformer):
|
|
639
708
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
640
709
|
inference_method=inference_method,
|
641
710
|
input_cols=self.input_cols,
|
642
|
-
expected_output_cols=
|
711
|
+
expected_output_cols=expected_output_cols,
|
643
712
|
**transform_kwargs
|
644
713
|
)
|
645
714
|
return output_df
|
@@ -665,30 +734,32 @@ class GaussianNB(BaseTransformer):
|
|
665
734
|
Output dataset with results of the decision function for the samples in input dataset.
|
666
735
|
"""
|
667
736
|
super()._check_dataset_type(dataset)
|
668
|
-
inference_method="decision_function"
|
737
|
+
inference_method = "decision_function"
|
669
738
|
|
670
739
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
671
740
|
# are specific to the type of dataset used.
|
672
741
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
673
742
|
|
743
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
744
|
+
|
674
745
|
if isinstance(dataset, DataFrame):
|
675
|
-
self.
|
676
|
-
|
677
|
-
|
678
|
-
|
679
|
-
|
746
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
747
|
+
self._deps = self._get_dependencies()
|
748
|
+
assert isinstance(
|
749
|
+
dataset._session, Session
|
750
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
680
751
|
transform_kwargs = dict(
|
681
752
|
session=dataset._session,
|
682
753
|
dependencies=self._deps,
|
683
|
-
drop_input_cols
|
754
|
+
drop_input_cols=self._drop_input_cols,
|
684
755
|
expected_output_cols_type="float",
|
685
756
|
)
|
757
|
+
expected_output_cols = self._align_expected_output_names(
|
758
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
759
|
+
)
|
686
760
|
|
687
761
|
elif isinstance(dataset, pd.DataFrame):
|
688
|
-
transform_kwargs = dict(
|
689
|
-
snowpark_input_cols = self._snowpark_cols,
|
690
|
-
drop_input_cols = self._drop_input_cols
|
691
|
-
)
|
762
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
692
763
|
|
693
764
|
transform_handlers = ModelTransformerBuilder.build(
|
694
765
|
dataset=dataset,
|
@@ -701,7 +772,7 @@ class GaussianNB(BaseTransformer):
|
|
701
772
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
702
773
|
inference_method=inference_method,
|
703
774
|
input_cols=self.input_cols,
|
704
|
-
expected_output_cols=
|
775
|
+
expected_output_cols=expected_output_cols,
|
705
776
|
**transform_kwargs
|
706
777
|
)
|
707
778
|
return output_df
|
@@ -730,17 +801,17 @@ class GaussianNB(BaseTransformer):
|
|
730
801
|
Output dataset with probability of the sample for each class in the model.
|
731
802
|
"""
|
732
803
|
super()._check_dataset_type(dataset)
|
733
|
-
inference_method="score_samples"
|
804
|
+
inference_method = "score_samples"
|
734
805
|
|
735
806
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
736
807
|
# are specific to the type of dataset used.
|
737
808
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
738
809
|
|
810
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
811
|
+
|
739
812
|
if isinstance(dataset, DataFrame):
|
740
|
-
self.
|
741
|
-
|
742
|
-
inference_method=inference_method,
|
743
|
-
)
|
813
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
814
|
+
self._deps = self._get_dependencies()
|
744
815
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
745
816
|
transform_kwargs = dict(
|
746
817
|
session=dataset._session,
|
@@ -748,6 +819,9 @@ class GaussianNB(BaseTransformer):
|
|
748
819
|
drop_input_cols = self._drop_input_cols,
|
749
820
|
expected_output_cols_type="float",
|
750
821
|
)
|
822
|
+
expected_output_cols = self._align_expected_output_names(
|
823
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
824
|
+
)
|
751
825
|
|
752
826
|
elif isinstance(dataset, pd.DataFrame):
|
753
827
|
transform_kwargs = dict(
|
@@ -766,7 +840,7 @@ class GaussianNB(BaseTransformer):
|
|
766
840
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
767
841
|
inference_method=inference_method,
|
768
842
|
input_cols=self.input_cols,
|
769
|
-
expected_output_cols=
|
843
|
+
expected_output_cols=expected_output_cols,
|
770
844
|
**transform_kwargs
|
771
845
|
)
|
772
846
|
return output_df
|
@@ -801,17 +875,15 @@ class GaussianNB(BaseTransformer):
|
|
801
875
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
802
876
|
|
803
877
|
if isinstance(dataset, DataFrame):
|
804
|
-
self.
|
805
|
-
|
806
|
-
inference_method="score",
|
807
|
-
)
|
878
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
879
|
+
self._deps = self._get_dependencies()
|
808
880
|
selected_cols = self._get_active_columns()
|
809
881
|
if len(selected_cols) > 0:
|
810
882
|
dataset = dataset.select(selected_cols)
|
811
883
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
812
884
|
transform_kwargs = dict(
|
813
885
|
session=dataset._session,
|
814
|
-
dependencies=
|
886
|
+
dependencies=self._deps,
|
815
887
|
score_sproc_imports=['sklearn'],
|
816
888
|
)
|
817
889
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -876,11 +948,8 @@ class GaussianNB(BaseTransformer):
|
|
876
948
|
|
877
949
|
if isinstance(dataset, DataFrame):
|
878
950
|
|
879
|
-
self.
|
880
|
-
|
881
|
-
inference_method=inference_method,
|
882
|
-
|
883
|
-
)
|
951
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
952
|
+
self._deps = self._get_dependencies()
|
884
953
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
885
954
|
transform_kwargs = dict(
|
886
955
|
session = dataset._session,
|
@@ -913,50 +982,84 @@ class GaussianNB(BaseTransformer):
|
|
913
982
|
)
|
914
983
|
return output_df
|
915
984
|
|
985
|
+
|
986
|
+
|
987
|
+
def to_sklearn(self) -> Any:
|
988
|
+
"""Get sklearn.naive_bayes.GaussianNB object.
|
989
|
+
"""
|
990
|
+
if self._sklearn_object is None:
|
991
|
+
self._sklearn_object = self._create_sklearn_object()
|
992
|
+
return self._sklearn_object
|
993
|
+
|
994
|
+
def to_xgboost(self) -> Any:
|
995
|
+
raise exceptions.SnowflakeMLException(
|
996
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
997
|
+
original_exception=AttributeError(
|
998
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
999
|
+
"to_xgboost()",
|
1000
|
+
"to_sklearn()"
|
1001
|
+
)
|
1002
|
+
),
|
1003
|
+
)
|
1004
|
+
|
1005
|
+
def to_lightgbm(self) -> Any:
|
1006
|
+
raise exceptions.SnowflakeMLException(
|
1007
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1008
|
+
original_exception=AttributeError(
|
1009
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1010
|
+
"to_lightgbm()",
|
1011
|
+
"to_sklearn()"
|
1012
|
+
)
|
1013
|
+
),
|
1014
|
+
)
|
1015
|
+
|
1016
|
+
def _get_dependencies(self) -> List[str]:
|
1017
|
+
return self._deps
|
1018
|
+
|
916
1019
|
|
917
|
-
def
|
1020
|
+
def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
918
1021
|
self._model_signature_dict = dict()
|
919
1022
|
|
920
1023
|
PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
|
921
1024
|
|
922
|
-
inputs = list(_infer_signature(dataset[self.input_cols], "input"))
|
1025
|
+
inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
|
923
1026
|
outputs: List[BaseFeatureSpec] = []
|
924
1027
|
if hasattr(self, "predict"):
|
925
1028
|
# keep mypy happy
|
926
|
-
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1029
|
+
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
927
1030
|
# For classifier, the type of predict is the same as the type of label
|
928
|
-
if self._sklearn_object._estimator_type ==
|
929
|
-
|
1031
|
+
if self._sklearn_object._estimator_type == "classifier":
|
1032
|
+
# label columns is the desired type for output
|
930
1033
|
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
931
1034
|
# rename the output columns
|
932
1035
|
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
933
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
934
|
-
|
935
|
-
|
1036
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1037
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1038
|
+
)
|
936
1039
|
# For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
|
937
1040
|
# For outlier models, returns -1 for outliers and 1 for inliers.
|
938
|
-
# Clusterer returns int64 cluster labels.
|
1041
|
+
# Clusterer returns int64 cluster labels.
|
939
1042
|
elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
|
940
1043
|
outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
|
941
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
942
|
-
|
943
|
-
|
944
|
-
|
1044
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1045
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1046
|
+
)
|
1047
|
+
|
945
1048
|
# For regressor, the type of predict is float64
|
946
|
-
elif self._sklearn_object._estimator_type ==
|
1049
|
+
elif self._sklearn_object._estimator_type == "regressor":
|
947
1050
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
948
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
949
|
-
|
950
|
-
|
951
|
-
|
1051
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1052
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1053
|
+
)
|
1054
|
+
|
952
1055
|
for prob_func in PROB_FUNCTIONS:
|
953
1056
|
if hasattr(self, prob_func):
|
954
1057
|
output_cols_prefix: str = f"{prob_func}_"
|
955
1058
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
956
1059
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
957
|
-
self._model_signature_dict[prob_func] = ModelSignature(
|
958
|
-
|
959
|
-
|
1060
|
+
self._model_signature_dict[prob_func] = ModelSignature(
|
1061
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1062
|
+
)
|
960
1063
|
|
961
1064
|
# Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
|
962
1065
|
items = list(self._model_signature_dict.items())
|
@@ -969,10 +1072,10 @@ class GaussianNB(BaseTransformer):
|
|
969
1072
|
"""Returns model signature of current class.
|
970
1073
|
|
971
1074
|
Raises:
|
972
|
-
|
1075
|
+
SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
|
973
1076
|
|
974
1077
|
Returns:
|
975
|
-
Dict
|
1078
|
+
Dict with each method and its input output signature
|
976
1079
|
"""
|
977
1080
|
if self._model_signature_dict is None:
|
978
1081
|
raise exceptions.SnowflakeMLException(
|
@@ -980,35 +1083,3 @@ class GaussianNB(BaseTransformer):
|
|
980
1083
|
original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
|
981
1084
|
)
|
982
1085
|
return self._model_signature_dict
|
983
|
-
|
984
|
-
def to_sklearn(self) -> Any:
|
985
|
-
"""Get sklearn.naive_bayes.GaussianNB object.
|
986
|
-
"""
|
987
|
-
if self._sklearn_object is None:
|
988
|
-
self._sklearn_object = self._create_sklearn_object()
|
989
|
-
return self._sklearn_object
|
990
|
-
|
991
|
-
def to_xgboost(self) -> Any:
|
992
|
-
raise exceptions.SnowflakeMLException(
|
993
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
994
|
-
original_exception=AttributeError(
|
995
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
996
|
-
"to_xgboost()",
|
997
|
-
"to_sklearn()"
|
998
|
-
)
|
999
|
-
),
|
1000
|
-
)
|
1001
|
-
|
1002
|
-
def to_lightgbm(self) -> Any:
|
1003
|
-
raise exceptions.SnowflakeMLException(
|
1004
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1005
|
-
original_exception=AttributeError(
|
1006
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1007
|
-
"to_lightgbm()",
|
1008
|
-
"to_sklearn()"
|
1009
|
-
)
|
1010
|
-
),
|
1011
|
-
)
|
1012
|
-
|
1013
|
-
def _get_dependencies(self) -> List[str]:
|
1014
|
-
return self._deps
|