snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/naive_bayes/complement_nb.py:

@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
     BatchInferenceKwargsTypedDict,
     ScoreKwargsTypedDict
 )
+from snowflake.ml.model._signatures import utils as model_signature_utils
+from snowflake.ml.model.model_signature import (
+    BaseFeatureSpec,
+    DataType,
+    FeatureSpec,
+    ModelSignature,
+    _infer_signature,
+    _rename_signature_with_snowflake_identifiers,
+)
 
 from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
 
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
     validate_sklearn_args,
 )
 
-from snowflake.ml.model.model_signature import (
-    DataType,
-    FeatureSpec,
-    ModelSignature,
-    _infer_signature,
-    _rename_signature_with_snowflake_identifiers,
-    BaseFeatureSpec,
-)
-from snowflake.ml.model._signatures import utils as model_signature_utils
-
 _PROJECT = "ModelDevelopment"
 # Derive subproject from module name by removing "sklearn"
 # and converting module name from underscore to CamelCase
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.naive_bayes".replace("sk
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class ComplementNB(BaseTransformer):
     r"""The Complement Naive Bayes classifier described in Rennie et al
     For more details on this class, see [sklearn.naive_bayes.ComplementNB]
@@ -222,12 +215,7 @@ class ComplementNB(BaseTransformer):
         )
         return selected_cols
 
-    @telemetry.send_api_usage_telemetry(
-        project=_PROJECT,
-        subproject=_SUBPROJECT,
-        custom_tags=dict([("autogen", True)]),
-    )
-    def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "ComplementNB":
+    def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "ComplementNB":
         """Fit Naive Bayes classifier according to X, y
         For more details on this function, see [sklearn.naive_bayes.ComplementNB.fit]
         (https://scikit-learn.org/stable/modules/generated/sklearn.naive_bayes.ComplementNB.html#sklearn.naive_bayes.ComplementNB.fit)
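Note (editorial, not part of the diff): the generated `fit` becomes `_fit` and drops its own telemetry decorator; the public entry point presumably moves to the shared `BaseTransformer` (note `snowflake/ml/modeling/framework/base.py` changes in this release, +72 -37), so caller code should be unchanged. A minimal sketch under that assumption, with hypothetical column names and toy data:

```python
# Sketch only: assumes BaseTransformer.fit() dispatches to the generated _fit() in 1.5.0.
import pandas as pd
from snowflake.ml.modeling.naive_bayes import ComplementNB

# Hypothetical toy frame; column names are illustrative.
df = pd.DataFrame({
    "F1": [0.0, 1.0, 2.0, 3.0],
    "F2": [1.0, 0.0, 1.0, 0.0],
    "LABEL": [0, 1, 0, 1],
})

clf = ComplementNB(input_cols=["F1", "F2"], label_cols=["LABEL"], output_cols=["PRED"])
clf.fit(df)            # public API unchanged; runs the generated _fit() internally
print(clf.predict(df))
```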
@@ -254,12 +242,14 @@ class ComplementNB(BaseTransformer):
 
         self._snowpark_cols = dataset.select(self.input_cols).columns
 
-
+        # If we are already in a stored procedure, no need to kick off another one.
         if SNOWML_SPROC_ENV in os.environ:
             statement_params = telemetry.get_function_usage_statement_params(
                 project=_PROJECT,
                 subproject=_SUBPROJECT,
-                function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), ComplementNB.__class__.__name__),
+                function_name=telemetry.get_statement_params_full_func_name(
+                    inspect.currentframe(), ComplementNB.__class__.__name__
+                ),
                 api_calls=[Session.call],
                 custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
             )
@@ -280,27 +270,24 @@ class ComplementNB(BaseTransformer):
             )
             self._sklearn_object = model_trainer.train()
             self._is_fitted = True
-            self._get_model_signatures(dataset)
+            self._generate_model_signatures(dataset)
             return self
 
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-            return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -318,9 +305,7 @@ class ComplementNB(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -356,7 +341,9 @@ class ComplementNB(BaseTransformer):
             # when it is classifier, infer the datatype from label columns
             if expected_type_inferred == "" and 'predict' in self.model_signatures:
                 # Batch inference takes a single expected output column type. Use the first columns type for now.
-                label_cols_signatures = [row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols]
+                label_cols_signatures = [
+                    row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
+                ]
                 if len(label_cols_signatures) == 0:
                     error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
                     raise exceptions.SnowflakeMLException(
@@ -364,25 +351,23 @@ class ComplementNB(BaseTransformer):
                     original_exception=ValueError(error_str),
                 )
 
-                expected_type_inferred = convert_sp_to_sf_type(
-                    label_cols_signatures[0].as_snowpark_type()
-                )
+                expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
-            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
+            assert isinstance(
+                dataset._session, Session
+            ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
-                session = dataset._session,
-                dependencies = self._deps,
-                drop_input_cols = self._drop_input_cols,
-                expected_output_cols_type = expected_type_inferred,
+                session=dataset._session,
+                dependencies=self._deps,
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_type=expected_type_inferred,
             )
 
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -422,7 +407,7 @@ class ComplementNB(BaseTransformer):
             Transformed dataset.
         """
         super()._check_dataset_type(dataset)
-        inference_method="transform"
+        inference_method = "transform"
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
@@ -452,24 +437,19 @@ class ComplementNB(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
-                session = dataset._session,
-                dependencies = self._deps,
-                drop_input_cols = self._drop_input_cols,
-                expected_output_cols_type = expected_dtype,
+                session=dataset._session,
+                dependencies=self._deps,
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_type=expected_dtype,
             )
 
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -488,7 +468,11 @@ class ComplementNB(BaseTransformer):
         return output_df
 
     @available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_predict_",) -> Union[DataFrame, pd.DataFrame]:
+    def fit_predict(
+        self,
+        dataset: Union[DataFrame, pd.DataFrame],
+        output_cols_prefix: str = "fit_predict_",
+    ) -> Union[DataFrame, pd.DataFrame]:
         """ Method not supported for this class.
 
 
@@ -513,22 +497,104 @@ class ComplementNB(BaseTransformer):
         )
         output_result, fitted_estimator = model_trainer.train_fit_predict(
             drop_input_cols=self._drop_input_cols,
-            expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
+            expected_output_cols_list=(
+                self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
+            ),
         )
         self._sklearn_object = fitted_estimator
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
+
 
-    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
-    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
-        """.
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.fit(dataset)
-        assert self._sklearn_object is not None
-        return self._sklearn_object.embedding_
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
+
+
+    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
+        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
+        """
+        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
+        # The following condition is introduced for kneighbors methods, and not used in other methods
+        if output_cols:
+            output_cols = [
+                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
+                for c in output_cols
+            ]
+        elif getattr(self._sklearn_object, "classes_", None) is None:
+            output_cols = [output_cols_prefix]
+        elif self._sklearn_object is not None:
+            classes = self._sklearn_object.classes_
+            if isinstance(classes, numpy.ndarray):
+                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
+            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
+                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
+                output_cols = []
+                for i, cl in enumerate(classes):
+                    # For binary classification, there is only one output column for each class
+                    # ndarray as the two classes are complementary.
+                    if len(cl) == 2:
+                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
+                    else:
+                        output_cols.extend([
+                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
+                        ])
+            else:
+                output_cols = []
+
+        # Make sure column names are valid snowflake identifiers.
+        assert output_cols is not None # Make MyPy happy
+        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
+
+        return rv
+
+    def _align_expected_output_names(
+        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
+    ) -> List[str]:
+        # in case the inferred output column names dimension is different
+        # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
+        output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
+        output_df_columns = list(output_df_pd.columns)
+        output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
+        if self.sample_weight_col:
+            output_df_columns_set -= set(self.sample_weight_col)
+        # if the dimension of inferred output column names is correct; use it
+        if len(expected_output_cols_list) == len(output_df_columns_set):
+            return expected_output_cols_list
+        # otherwise, use the sklearn estimator's output
+        else:
+            return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
 
     @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
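Note (editorial, not part of the diff): the naming rule implemented by the new `_get_output_column_names` helper can be restated standalone. For a fitted single-output classifier with `classes_ = [0, 1, 2]`, the `predict_proba` output columns become `predict_proba_0`, `predict_proba_1`, `predict_proba_2`; a non-classifier gets the bare prefix as its single column. Identifier resolution via `snowflake.ml._internal.utils.identifier` is omitted here for brevity; the function name is hypothetical.

```python
# Illustration only: mirrors the column-naming logic of the helper added above.
import numpy

def sketch_output_column_names(prefix, classes):
    if classes is None:                      # not a classifier: one bare-prefix column
        return [prefix]
    if isinstance(classes, numpy.ndarray):   # single-output classifier: one column per class
        return [f"{prefix}{c}" for c in classes.tolist()]
    cols = []                                # multioutput classifier: classes_ is a list of ndarrays
    for i, cl in enumerate(classes):
        if len(cl) == 2:                     # binary output: one column, classes are complementary
            cols.append(f"{prefix}{i}_{cl[0]}")
        else:
            cols.extend(f"{prefix}{i}_{c}" for c in cl.tolist())
    return cols

assert sketch_output_column_names("predict_proba_", numpy.array([0, 1, 2])) == [
    "predict_proba_0", "predict_proba_1", "predict_proba_2"
]
```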
@@ -562,24 +628,26 @@ class ComplementNB(BaseTransformer):
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
+
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
-            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
+            assert isinstance(
+                dataset._session, Session
+            ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                drop_input_cols = self._drop_input_cols,
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
 
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -591,7 +659,7 @@ class ComplementNB(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=self._get_output_column_names(output_cols_prefix),
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -623,29 +691,30 @@ class ComplementNB(BaseTransformer):
             Output dataset with log probability of the sample for each class in the model.
         """
         super()._check_dataset_type(dataset)
-        inference_method="predict_log_proba"
+        inference_method = "predict_log_proba"
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
-            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
+            assert isinstance(
+                dataset._session, Session
+            ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                drop_input_cols = self._drop_input_cols,
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -658,7 +727,7 @@ class ComplementNB(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=self._get_output_column_names(output_cols_prefix),
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -684,30 +753,32 @@ class ComplementNB(BaseTransformer):
             Output dataset with results of the decision function for the samples in input dataset.
         """
         super()._check_dataset_type(dataset)
-        inference_method="decision_function"
+        inference_method = "decision_function"
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
+
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
-            assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
+            assert isinstance(
+                dataset._session, Session
+            ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                drop_input_cols = self._drop_input_cols,
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
 
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -720,7 +791,7 @@ class ComplementNB(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=self._get_output_column_names(output_cols_prefix),
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -749,17 +820,17 @@ class ComplementNB(BaseTransformer):
             Output dataset with probability of the sample for each class in the model.
         """
         super()._check_dataset_type(dataset)
-        inference_method="score_samples"
+        inference_method = "score_samples"
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
+
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -767,6 +838,9 @@ class ComplementNB(BaseTransformer):
                 drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
 
         elif isinstance(dataset, pd.DataFrame):
             transform_kwargs = dict(
@@ -785,7 +859,7 @@ class ComplementNB(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=self._get_output_column_names(output_cols_prefix),
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -820,17 +894,15 @@ class ComplementNB(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session) # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._deps,
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -895,11 +967,8 @@ class ComplementNB(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
@@ -932,50 +1001,84 @@ class ComplementNB(BaseTransformer):
         )
         return output_df
 
+
+
+    def to_sklearn(self) -> Any:
+        """Get sklearn.naive_bayes.ComplementNB object.
+        """
+        if self._sklearn_object is None:
+            self._sklearn_object = self._create_sklearn_object()
+        return self._sklearn_object
+
+    def to_xgboost(self) -> Any:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.METHOD_NOT_ALLOWED,
+            original_exception=AttributeError(
+                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
+                    "to_xgboost()",
+                    "to_sklearn()"
+                )
+            ),
+        )
+
+    def to_lightgbm(self) -> Any:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.METHOD_NOT_ALLOWED,
+            original_exception=AttributeError(
+                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
+                    "to_lightgbm()",
+                    "to_sklearn()"
+                )
+            ),
+        )
+
+    def _get_dependencies(self) -> List[str]:
+        return self._deps
+
 
-    def _get_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
+    def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         self._model_signature_dict = dict()
 
         PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
 
-        inputs = list(_infer_signature(dataset[self.input_cols], "input"))
+        inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
         outputs: List[BaseFeatureSpec] = []
         if hasattr(self, "predict"):
             # keep mypy happy
-            assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
+            assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
             # For classifier, the type of predict is the same as the type of label
-            if self._sklearn_object._estimator_type == 'classifier':
-                # label columns is the desired type for output
+            if self._sklearn_object._estimator_type == "classifier":
+                # label columns is the desired type for output
                 outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
                 # rename the output columns
                 outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
-                self._model_signature_dict["predict"] = ModelSignature(inputs,
-                                                                       ([] if self._drop_input_cols else inputs)
-                                                                       + outputs)
+                self._model_signature_dict["predict"] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
             # For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
             # For outlier models, returns -1 for outliers and 1 for inliers.
-            # Clusterer returns int64 cluster labels.
+            # Clusterer returns int64 cluster labels.
             elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
                 outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
-                self._model_signature_dict["predict"] = ModelSignature(inputs,
-                                                                       ([] if self._drop_input_cols else inputs)
-                                                                       + outputs)
-
+                self._model_signature_dict["predict"] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
+
             # For regressor, the type of predict is float64
-            elif self._sklearn_object._estimator_type == 'regressor':
+            elif self._sklearn_object._estimator_type == "regressor":
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-                self._model_signature_dict["predict"] = ModelSignature(inputs,
-                                                                       ([] if self._drop_input_cols else inputs)
-                                                                       + outputs)
-
+                self._model_signature_dict["predict"] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
+
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(inputs,
-                                                                       ([] if self._drop_input_cols else inputs)
-                                                                       + outputs)
+                self._model_signature_dict[prob_func] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
 
         # Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
         items = list(self._model_signature_dict.items())
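Note (editorial, not part of the diff): the conversion helpers move ahead of the signature code but keep their behavior: `to_sklearn()` returns the wrapped estimator (constructing it lazily if needed), while `to_xgboost()` and `to_lightgbm()` raise `SnowflakeMLException` with `METHOD_NOT_ALLOWED`. A short sketch, reusing the hypothetical `clf` from the earlier example:

```python
# Sketch only: `clf` is the hypothetical fitted ComplementNB from the earlier example.
skl = clf.to_sklearn()        # underlying sklearn.naive_bayes.ComplementNB
print(type(skl))

try:
    clf.to_xgboost()          # wrong framework for this estimator
except Exception as exc:      # raises SnowflakeMLException (METHOD_NOT_ALLOWED)
    print(f"not allowed: {exc}")
```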
@@ -988,10 +1091,10 @@ class ComplementNB(BaseTransformer):
         """Returns model signature of current class.
 
         Raises:
-            exceptions.SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
+            SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
 
         Returns:
-            Dict[str, ModelSignature]: each method and its input output signature
+            Dict with each method and its input output signature
         """
         if self._model_signature_dict is None:
             raise exceptions.SnowflakeMLException(
@@ -999,35 +1102,3 @@ class ComplementNB(BaseTransformer):
                 original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
             )
         return self._model_signature_dict
-
-    def to_sklearn(self) -> Any:
-        """Get sklearn.naive_bayes.ComplementNB object.
-        """
-        if self._sklearn_object is None:
-            self._sklearn_object = self._create_sklearn_object()
-        return self._sklearn_object
-
-    def to_xgboost(self) -> Any:
-        raise exceptions.SnowflakeMLException(
-            error_code=error_codes.METHOD_NOT_ALLOWED,
-            original_exception=AttributeError(
-                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
-                    "to_xgboost()",
-                    "to_sklearn()"
-                )
-            ),
-        )
-
-    def to_lightgbm(self) -> Any:
-        raise exceptions.SnowflakeMLException(
-            error_code=error_codes.METHOD_NOT_ALLOWED,
-            original_exception=AttributeError(
-                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
-                    "to_lightgbm()",
-                    "to_sklearn()"
-                )
-            ),
-        )
-
-    def _get_dependencies(self) -> List[str]:
-        return self._deps