snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
|
|
33
33
|
BatchInferenceKwargsTypedDict,
|
34
34
|
ScoreKwargsTypedDict
|
35
35
|
)
|
36
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
37
|
+
from snowflake.ml.model.model_signature import (
|
38
|
+
BaseFeatureSpec,
|
39
|
+
DataType,
|
40
|
+
FeatureSpec,
|
41
|
+
ModelSignature,
|
42
|
+
_infer_signature,
|
43
|
+
_rename_signature_with_snowflake_identifiers,
|
44
|
+
)
|
36
45
|
|
37
46
|
from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
|
38
47
|
|
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
|
|
43
52
|
validate_sklearn_args,
|
44
53
|
)
|
45
54
|
|
46
|
-
from snowflake.ml.model.model_signature import (
|
47
|
-
DataType,
|
48
|
-
FeatureSpec,
|
49
|
-
ModelSignature,
|
50
|
-
_infer_signature,
|
51
|
-
_rename_signature_with_snowflake_identifiers,
|
52
|
-
BaseFeatureSpec,
|
53
|
-
)
|
54
|
-
from snowflake.ml.model._signatures import utils as model_signature_utils
|
55
|
-
|
56
55
|
_PROJECT = "ModelDevelopment"
|
57
56
|
# Derive subproject from module name by removing "sklearn"
|
58
57
|
# and converting module name from underscore to CamelCase
|
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.multiclass".replace("skl
|
|
61
60
|
|
62
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
63
62
|
|
64
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
65
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
66
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
67
|
-
return check
|
68
|
-
|
69
|
-
|
70
63
|
class OneVsRestClassifier(BaseTransformer):
|
71
64
|
r"""One-vs-the-rest (OvR) multiclass strategy
|
72
65
|
For more details on this class, see [sklearn.multiclass.OneVsRestClassifier]
|
@@ -219,12 +212,7 @@ class OneVsRestClassifier(BaseTransformer):
|
|
219
212
|
)
|
220
213
|
return selected_cols
|
221
214
|
|
222
|
-
|
223
|
-
project=_PROJECT,
|
224
|
-
subproject=_SUBPROJECT,
|
225
|
-
custom_tags=dict([("autogen", True)]),
|
226
|
-
)
|
227
|
-
def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "OneVsRestClassifier":
|
215
|
+
def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "OneVsRestClassifier":
|
228
216
|
"""Fit underlying estimators
|
229
217
|
For more details on this function, see [sklearn.multiclass.OneVsRestClassifier.fit]
|
230
218
|
(https://scikit-learn.org/stable/modules/generated/sklearn.multiclass.OneVsRestClassifier.html#sklearn.multiclass.OneVsRestClassifier.fit)
|
@@ -251,12 +239,14 @@ class OneVsRestClassifier(BaseTransformer):
|
|
251
239
|
|
252
240
|
self._snowpark_cols = dataset.select(self.input_cols).columns
|
253
241
|
|
254
|
-
|
242
|
+
# If we are already in a stored procedure, no need to kick off another one.
|
255
243
|
if SNOWML_SPROC_ENV in os.environ:
|
256
244
|
statement_params = telemetry.get_function_usage_statement_params(
|
257
245
|
project=_PROJECT,
|
258
246
|
subproject=_SUBPROJECT,
|
259
|
-
function_name=telemetry.get_statement_params_full_func_name(
|
247
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
248
|
+
inspect.currentframe(), OneVsRestClassifier.__class__.__name__
|
249
|
+
),
|
260
250
|
api_calls=[Session.call],
|
261
251
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
262
252
|
)
|
@@ -277,27 +267,24 @@ class OneVsRestClassifier(BaseTransformer):
|
|
277
267
|
)
|
278
268
|
self._sklearn_object = model_trainer.train()
|
279
269
|
self._is_fitted = True
|
280
|
-
self.
|
270
|
+
self._generate_model_signatures(dataset)
|
281
271
|
return self
|
282
272
|
|
283
273
|
def _batch_inference_validate_snowpark(
|
284
274
|
self,
|
285
275
|
dataset: DataFrame,
|
286
276
|
inference_method: str,
|
287
|
-
) ->
|
288
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
289
|
-
return the available package that exists in the snowflake anaconda channel
|
277
|
+
) -> None:
|
278
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
290
279
|
|
291
280
|
Args:
|
292
281
|
dataset: snowpark dataframe
|
293
282
|
inference_method: the inference method such as predict, score...
|
294
|
-
|
283
|
+
|
295
284
|
Raises:
|
296
285
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
297
286
|
SnowflakeMLException: If the session is None, raise error
|
298
287
|
|
299
|
-
Returns:
|
300
|
-
A list of available package that exists in the snowflake anaconda channel
|
301
288
|
"""
|
302
289
|
if not self._is_fitted:
|
303
290
|
raise exceptions.SnowflakeMLException(
|
@@ -315,9 +302,7 @@ class OneVsRestClassifier(BaseTransformer):
|
|
315
302
|
"Session must not specified for snowpark dataset."
|
316
303
|
),
|
317
304
|
)
|
318
|
-
|
319
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
320
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
305
|
+
|
321
306
|
|
322
307
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
323
308
|
@telemetry.send_api_usage_telemetry(
|
@@ -353,7 +338,9 @@ class OneVsRestClassifier(BaseTransformer):
|
|
353
338
|
# when it is classifier, infer the datatype from label columns
|
354
339
|
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
355
340
|
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
356
|
-
label_cols_signatures = [
|
341
|
+
label_cols_signatures = [
|
342
|
+
row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
|
343
|
+
]
|
357
344
|
if len(label_cols_signatures) == 0:
|
358
345
|
error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
|
359
346
|
raise exceptions.SnowflakeMLException(
|
@@ -361,25 +348,23 @@ class OneVsRestClassifier(BaseTransformer):
|
|
361
348
|
original_exception=ValueError(error_str),
|
362
349
|
)
|
363
350
|
|
364
|
-
expected_type_inferred = convert_sp_to_sf_type(
|
365
|
-
label_cols_signatures[0].as_snowpark_type()
|
366
|
-
)
|
351
|
+
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
367
352
|
|
368
|
-
self.
|
369
|
-
|
353
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
354
|
+
self._deps = self._get_dependencies()
|
355
|
+
assert isinstance(
|
356
|
+
dataset._session, Session
|
357
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
370
358
|
|
371
359
|
transform_kwargs = dict(
|
372
|
-
session
|
373
|
-
dependencies
|
374
|
-
drop_input_cols
|
375
|
-
expected_output_cols_type
|
360
|
+
session=dataset._session,
|
361
|
+
dependencies=self._deps,
|
362
|
+
drop_input_cols=self._drop_input_cols,
|
363
|
+
expected_output_cols_type=expected_type_inferred,
|
376
364
|
)
|
377
365
|
|
378
366
|
elif isinstance(dataset, pd.DataFrame):
|
379
|
-
transform_kwargs = dict(
|
380
|
-
snowpark_input_cols = self._snowpark_cols,
|
381
|
-
drop_input_cols = self._drop_input_cols
|
382
|
-
)
|
367
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
383
368
|
|
384
369
|
transform_handlers = ModelTransformerBuilder.build(
|
385
370
|
dataset=dataset,
|
@@ -419,7 +404,7 @@ class OneVsRestClassifier(BaseTransformer):
|
|
419
404
|
Transformed dataset.
|
420
405
|
"""
|
421
406
|
super()._check_dataset_type(dataset)
|
422
|
-
inference_method="transform"
|
407
|
+
inference_method = "transform"
|
423
408
|
|
424
409
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
425
410
|
# are specific to the type of dataset used.
|
@@ -449,24 +434,19 @@ class OneVsRestClassifier(BaseTransformer):
|
|
449
434
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
450
435
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
451
436
|
|
452
|
-
self.
|
453
|
-
|
454
|
-
inference_method=inference_method,
|
455
|
-
)
|
437
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
438
|
+
self._deps = self._get_dependencies()
|
456
439
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
457
440
|
|
458
441
|
transform_kwargs = dict(
|
459
|
-
session
|
460
|
-
dependencies
|
461
|
-
drop_input_cols
|
462
|
-
expected_output_cols_type
|
442
|
+
session=dataset._session,
|
443
|
+
dependencies=self._deps,
|
444
|
+
drop_input_cols=self._drop_input_cols,
|
445
|
+
expected_output_cols_type=expected_dtype,
|
463
446
|
)
|
464
447
|
|
465
448
|
elif isinstance(dataset, pd.DataFrame):
|
466
|
-
transform_kwargs = dict(
|
467
|
-
snowpark_input_cols = self._snowpark_cols,
|
468
|
-
drop_input_cols = self._drop_input_cols
|
469
|
-
)
|
449
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
470
450
|
|
471
451
|
transform_handlers = ModelTransformerBuilder.build(
|
472
452
|
dataset=dataset,
|
@@ -485,7 +465,11 @@ class OneVsRestClassifier(BaseTransformer):
|
|
485
465
|
return output_df
|
486
466
|
|
487
467
|
@available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
|
488
|
-
def fit_predict(
|
468
|
+
def fit_predict(
|
469
|
+
self,
|
470
|
+
dataset: Union[DataFrame, pd.DataFrame],
|
471
|
+
output_cols_prefix: str = "fit_predict_",
|
472
|
+
) -> Union[DataFrame, pd.DataFrame]:
|
489
473
|
""" Method not supported for this class.
|
490
474
|
|
491
475
|
|
@@ -510,22 +494,104 @@ class OneVsRestClassifier(BaseTransformer):
|
|
510
494
|
)
|
511
495
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
512
496
|
drop_input_cols=self._drop_input_cols,
|
513
|
-
expected_output_cols_list=
|
497
|
+
expected_output_cols_list=(
|
498
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
499
|
+
),
|
514
500
|
)
|
515
501
|
self._sklearn_object = fitted_estimator
|
516
502
|
self._is_fitted = True
|
517
503
|
return output_result
|
518
504
|
|
505
|
+
|
506
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
507
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
508
|
+
""" Method not supported for this class.
|
509
|
+
|
519
510
|
|
520
|
-
|
521
|
-
|
522
|
-
|
511
|
+
Raises:
|
512
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
513
|
+
|
514
|
+
Args:
|
515
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
516
|
+
Snowpark or Pandas DataFrame.
|
517
|
+
output_cols_prefix: Prefix for the response columns
|
523
518
|
Returns:
|
524
519
|
Transformed dataset.
|
525
520
|
"""
|
526
|
-
self.
|
527
|
-
|
528
|
-
|
521
|
+
self._infer_input_output_cols(dataset)
|
522
|
+
super()._check_dataset_type(dataset)
|
523
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
524
|
+
estimator=self._sklearn_object,
|
525
|
+
dataset=dataset,
|
526
|
+
input_cols=self.input_cols,
|
527
|
+
label_cols=self.label_cols,
|
528
|
+
sample_weight_col=self.sample_weight_col,
|
529
|
+
autogenerated=self._autogenerated,
|
530
|
+
subproject=_SUBPROJECT,
|
531
|
+
)
|
532
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
533
|
+
drop_input_cols=self._drop_input_cols,
|
534
|
+
expected_output_cols_list=self.output_cols,
|
535
|
+
)
|
536
|
+
self._sklearn_object = fitted_estimator
|
537
|
+
self._is_fitted = True
|
538
|
+
return output_result
|
539
|
+
|
540
|
+
|
541
|
+
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
542
|
+
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
543
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
544
|
+
"""
|
545
|
+
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
546
|
+
# The following condition is introduced for kneighbors methods, and not used in other methods
|
547
|
+
if output_cols:
|
548
|
+
output_cols = [
|
549
|
+
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
550
|
+
for c in output_cols
|
551
|
+
]
|
552
|
+
elif getattr(self._sklearn_object, "classes_", None) is None:
|
553
|
+
output_cols = [output_cols_prefix]
|
554
|
+
elif self._sklearn_object is not None:
|
555
|
+
classes = self._sklearn_object.classes_
|
556
|
+
if isinstance(classes, numpy.ndarray):
|
557
|
+
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
558
|
+
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
559
|
+
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
560
|
+
output_cols = []
|
561
|
+
for i, cl in enumerate(classes):
|
562
|
+
# For binary classification, there is only one output column for each class
|
563
|
+
# ndarray as the two classes are complementary.
|
564
|
+
if len(cl) == 2:
|
565
|
+
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
566
|
+
else:
|
567
|
+
output_cols.extend([
|
568
|
+
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
569
|
+
])
|
570
|
+
else:
|
571
|
+
output_cols = []
|
572
|
+
|
573
|
+
# Make sure column names are valid snowflake identifiers.
|
574
|
+
assert output_cols is not None # Make MyPy happy
|
575
|
+
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
576
|
+
|
577
|
+
return rv
|
578
|
+
|
579
|
+
def _align_expected_output_names(
|
580
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
581
|
+
) -> List[str]:
|
582
|
+
# in case the inferred output column names dimension is different
|
583
|
+
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
584
|
+
output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
|
585
|
+
output_df_columns = list(output_df_pd.columns)
|
586
|
+
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
587
|
+
if self.sample_weight_col:
|
588
|
+
output_df_columns_set -= set(self.sample_weight_col)
|
589
|
+
# if the dimension of inferred output column names is correct; use it
|
590
|
+
if len(expected_output_cols_list) == len(output_df_columns_set):
|
591
|
+
return expected_output_cols_list
|
592
|
+
# otherwise, use the sklearn estimator's output
|
593
|
+
else:
|
594
|
+
return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
529
595
|
|
530
596
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
531
597
|
@telemetry.send_api_usage_telemetry(
|
@@ -559,24 +625,26 @@ class OneVsRestClassifier(BaseTransformer):
|
|
559
625
|
# are specific to the type of dataset used.
|
560
626
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
561
627
|
|
628
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
629
|
+
|
562
630
|
if isinstance(dataset, DataFrame):
|
563
|
-
self.
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
|
631
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
632
|
+
self._deps = self._get_dependencies()
|
633
|
+
assert isinstance(
|
634
|
+
dataset._session, Session
|
635
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
568
636
|
transform_kwargs = dict(
|
569
637
|
session=dataset._session,
|
570
638
|
dependencies=self._deps,
|
571
|
-
drop_input_cols
|
639
|
+
drop_input_cols=self._drop_input_cols,
|
572
640
|
expected_output_cols_type="float",
|
573
641
|
)
|
642
|
+
expected_output_cols = self._align_expected_output_names(
|
643
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
644
|
+
)
|
574
645
|
|
575
646
|
elif isinstance(dataset, pd.DataFrame):
|
576
|
-
transform_kwargs = dict(
|
577
|
-
snowpark_input_cols = self._snowpark_cols,
|
578
|
-
drop_input_cols = self._drop_input_cols
|
579
|
-
)
|
647
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
580
648
|
|
581
649
|
transform_handlers = ModelTransformerBuilder.build(
|
582
650
|
dataset=dataset,
|
@@ -588,7 +656,7 @@ class OneVsRestClassifier(BaseTransformer):
|
|
588
656
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
589
657
|
inference_method=inference_method,
|
590
658
|
input_cols=self.input_cols,
|
591
|
-
expected_output_cols=
|
659
|
+
expected_output_cols=expected_output_cols,
|
592
660
|
**transform_kwargs
|
593
661
|
)
|
594
662
|
return output_df
|
@@ -620,29 +688,30 @@ class OneVsRestClassifier(BaseTransformer):
|
|
620
688
|
Output dataset with log probability of the sample for each class in the model.
|
621
689
|
"""
|
622
690
|
super()._check_dataset_type(dataset)
|
623
|
-
inference_method="predict_log_proba"
|
691
|
+
inference_method = "predict_log_proba"
|
692
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
624
693
|
|
625
694
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
626
695
|
# are specific to the type of dataset used.
|
627
696
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
628
697
|
|
629
698
|
if isinstance(dataset, DataFrame):
|
630
|
-
self.
|
631
|
-
|
632
|
-
|
633
|
-
|
634
|
-
|
699
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
700
|
+
self._deps = self._get_dependencies()
|
701
|
+
assert isinstance(
|
702
|
+
dataset._session, Session
|
703
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
635
704
|
transform_kwargs = dict(
|
636
705
|
session=dataset._session,
|
637
706
|
dependencies=self._deps,
|
638
|
-
drop_input_cols
|
707
|
+
drop_input_cols=self._drop_input_cols,
|
639
708
|
expected_output_cols_type="float",
|
640
709
|
)
|
710
|
+
expected_output_cols = self._align_expected_output_names(
|
711
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
712
|
+
)
|
641
713
|
elif isinstance(dataset, pd.DataFrame):
|
642
|
-
transform_kwargs = dict(
|
643
|
-
snowpark_input_cols = self._snowpark_cols,
|
644
|
-
drop_input_cols = self._drop_input_cols
|
645
|
-
)
|
714
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
646
715
|
|
647
716
|
transform_handlers = ModelTransformerBuilder.build(
|
648
717
|
dataset=dataset,
|
@@ -655,7 +724,7 @@ class OneVsRestClassifier(BaseTransformer):
|
|
655
724
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
656
725
|
inference_method=inference_method,
|
657
726
|
input_cols=self.input_cols,
|
658
|
-
expected_output_cols=
|
727
|
+
expected_output_cols=expected_output_cols,
|
659
728
|
**transform_kwargs
|
660
729
|
)
|
661
730
|
return output_df
|
@@ -683,30 +752,32 @@ class OneVsRestClassifier(BaseTransformer):
|
|
683
752
|
Output dataset with results of the decision function for the samples in input dataset.
|
684
753
|
"""
|
685
754
|
super()._check_dataset_type(dataset)
|
686
|
-
inference_method="decision_function"
|
755
|
+
inference_method = "decision_function"
|
687
756
|
|
688
757
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
689
758
|
# are specific to the type of dataset used.
|
690
759
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
691
760
|
|
761
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
762
|
+
|
692
763
|
if isinstance(dataset, DataFrame):
|
693
|
-
self.
|
694
|
-
|
695
|
-
|
696
|
-
|
697
|
-
|
764
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
765
|
+
self._deps = self._get_dependencies()
|
766
|
+
assert isinstance(
|
767
|
+
dataset._session, Session
|
768
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
698
769
|
transform_kwargs = dict(
|
699
770
|
session=dataset._session,
|
700
771
|
dependencies=self._deps,
|
701
|
-
drop_input_cols
|
772
|
+
drop_input_cols=self._drop_input_cols,
|
702
773
|
expected_output_cols_type="float",
|
703
774
|
)
|
775
|
+
expected_output_cols = self._align_expected_output_names(
|
776
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
777
|
+
)
|
704
778
|
|
705
779
|
elif isinstance(dataset, pd.DataFrame):
|
706
|
-
transform_kwargs = dict(
|
707
|
-
snowpark_input_cols = self._snowpark_cols,
|
708
|
-
drop_input_cols = self._drop_input_cols
|
709
|
-
)
|
780
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
710
781
|
|
711
782
|
transform_handlers = ModelTransformerBuilder.build(
|
712
783
|
dataset=dataset,
|
@@ -719,7 +790,7 @@ class OneVsRestClassifier(BaseTransformer):
|
|
719
790
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
720
791
|
inference_method=inference_method,
|
721
792
|
input_cols=self.input_cols,
|
722
|
-
expected_output_cols=
|
793
|
+
expected_output_cols=expected_output_cols,
|
723
794
|
**transform_kwargs
|
724
795
|
)
|
725
796
|
return output_df
|
@@ -748,17 +819,17 @@ class OneVsRestClassifier(BaseTransformer):
|
|
748
819
|
Output dataset with probability of the sample for each class in the model.
|
749
820
|
"""
|
750
821
|
super()._check_dataset_type(dataset)
|
751
|
-
inference_method="score_samples"
|
822
|
+
inference_method = "score_samples"
|
752
823
|
|
753
824
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
754
825
|
# are specific to the type of dataset used.
|
755
826
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
756
827
|
|
828
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
829
|
+
|
757
830
|
if isinstance(dataset, DataFrame):
|
758
|
-
self.
|
759
|
-
|
760
|
-
inference_method=inference_method,
|
761
|
-
)
|
831
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
832
|
+
self._deps = self._get_dependencies()
|
762
833
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
763
834
|
transform_kwargs = dict(
|
764
835
|
session=dataset._session,
|
@@ -766,6 +837,9 @@ class OneVsRestClassifier(BaseTransformer):
|
|
766
837
|
drop_input_cols = self._drop_input_cols,
|
767
838
|
expected_output_cols_type="float",
|
768
839
|
)
|
840
|
+
expected_output_cols = self._align_expected_output_names(
|
841
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
842
|
+
)
|
769
843
|
|
770
844
|
elif isinstance(dataset, pd.DataFrame):
|
771
845
|
transform_kwargs = dict(
|
@@ -784,7 +858,7 @@ class OneVsRestClassifier(BaseTransformer):
|
|
784
858
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
785
859
|
inference_method=inference_method,
|
786
860
|
input_cols=self.input_cols,
|
787
|
-
expected_output_cols=
|
861
|
+
expected_output_cols=expected_output_cols,
|
788
862
|
**transform_kwargs
|
789
863
|
)
|
790
864
|
return output_df
|
@@ -819,17 +893,15 @@ class OneVsRestClassifier(BaseTransformer):
|
|
819
893
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
820
894
|
|
821
895
|
if isinstance(dataset, DataFrame):
|
822
|
-
self.
|
823
|
-
|
824
|
-
inference_method="score",
|
825
|
-
)
|
896
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
897
|
+
self._deps = self._get_dependencies()
|
826
898
|
selected_cols = self._get_active_columns()
|
827
899
|
if len(selected_cols) > 0:
|
828
900
|
dataset = dataset.select(selected_cols)
|
829
901
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
830
902
|
transform_kwargs = dict(
|
831
903
|
session=dataset._session,
|
832
|
-
dependencies=
|
904
|
+
dependencies=self._deps,
|
833
905
|
score_sproc_imports=['sklearn'],
|
834
906
|
)
|
835
907
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -894,11 +966,8 @@ class OneVsRestClassifier(BaseTransformer):
|
|
894
966
|
|
895
967
|
if isinstance(dataset, DataFrame):
|
896
968
|
|
897
|
-
self.
|
898
|
-
|
899
|
-
inference_method=inference_method,
|
900
|
-
|
901
|
-
)
|
969
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
970
|
+
self._deps = self._get_dependencies()
|
902
971
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
903
972
|
transform_kwargs = dict(
|
904
973
|
session = dataset._session,
|
@@ -931,50 +1000,84 @@ class OneVsRestClassifier(BaseTransformer):
|
|
931
1000
|
)
|
932
1001
|
return output_df
|
933
1002
|
|
1003
|
+
|
1004
|
+
|
1005
|
+
def to_sklearn(self) -> Any:
|
1006
|
+
"""Get sklearn.multiclass.OneVsRestClassifier object.
|
1007
|
+
"""
|
1008
|
+
if self._sklearn_object is None:
|
1009
|
+
self._sklearn_object = self._create_sklearn_object()
|
1010
|
+
return self._sklearn_object
|
1011
|
+
|
1012
|
+
def to_xgboost(self) -> Any:
|
1013
|
+
raise exceptions.SnowflakeMLException(
|
1014
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1015
|
+
original_exception=AttributeError(
|
1016
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1017
|
+
"to_xgboost()",
|
1018
|
+
"to_sklearn()"
|
1019
|
+
)
|
1020
|
+
),
|
1021
|
+
)
|
1022
|
+
|
1023
|
+
def to_lightgbm(self) -> Any:
|
1024
|
+
raise exceptions.SnowflakeMLException(
|
1025
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1026
|
+
original_exception=AttributeError(
|
1027
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1028
|
+
"to_lightgbm()",
|
1029
|
+
"to_sklearn()"
|
1030
|
+
)
|
1031
|
+
),
|
1032
|
+
)
|
1033
|
+
|
1034
|
+
def _get_dependencies(self) -> List[str]:
|
1035
|
+
return self._deps
|
1036
|
+
|
934
1037
|
|
935
|
-
def
|
1038
|
+
def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
936
1039
|
self._model_signature_dict = dict()
|
937
1040
|
|
938
1041
|
PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
|
939
1042
|
|
940
|
-
inputs = list(_infer_signature(dataset[self.input_cols], "input"))
|
1043
|
+
inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
|
941
1044
|
outputs: List[BaseFeatureSpec] = []
|
942
1045
|
if hasattr(self, "predict"):
|
943
1046
|
# keep mypy happy
|
944
|
-
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1047
|
+
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
945
1048
|
# For classifier, the type of predict is the same as the type of label
|
946
|
-
if self._sklearn_object._estimator_type ==
|
947
|
-
|
1049
|
+
if self._sklearn_object._estimator_type == "classifier":
|
1050
|
+
# label columns is the desired type for output
|
948
1051
|
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
949
1052
|
# rename the output columns
|
950
1053
|
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
951
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
952
|
-
|
953
|
-
|
1054
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1055
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1056
|
+
)
|
954
1057
|
# For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
|
955
1058
|
# For outlier models, returns -1 for outliers and 1 for inliers.
|
956
|
-
# Clusterer returns int64 cluster labels.
|
1059
|
+
# Clusterer returns int64 cluster labels.
|
957
1060
|
elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
|
958
1061
|
outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
|
959
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
960
|
-
|
961
|
-
|
962
|
-
|
1062
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1063
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1064
|
+
)
|
1065
|
+
|
963
1066
|
# For regressor, the type of predict is float64
|
964
|
-
elif self._sklearn_object._estimator_type ==
|
1067
|
+
elif self._sklearn_object._estimator_type == "regressor":
|
965
1068
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
966
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
967
|
-
|
968
|
-
|
969
|
-
|
1069
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1070
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1071
|
+
)
|
1072
|
+
|
970
1073
|
for prob_func in PROB_FUNCTIONS:
|
971
1074
|
if hasattr(self, prob_func):
|
972
1075
|
output_cols_prefix: str = f"{prob_func}_"
|
973
1076
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
974
1077
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
975
|
-
self._model_signature_dict[prob_func] = ModelSignature(
|
976
|
-
|
977
|
-
|
1078
|
+
self._model_signature_dict[prob_func] = ModelSignature(
|
1079
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1080
|
+
)
|
978
1081
|
|
979
1082
|
# Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
|
980
1083
|
items = list(self._model_signature_dict.items())
|
@@ -987,10 +1090,10 @@ class OneVsRestClassifier(BaseTransformer):
|
|
987
1090
|
"""Returns model signature of current class.
|
988
1091
|
|
989
1092
|
Raises:
|
990
|
-
|
1093
|
+
SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
|
991
1094
|
|
992
1095
|
Returns:
|
993
|
-
Dict
|
1096
|
+
Dict with each method and its input output signature
|
994
1097
|
"""
|
995
1098
|
if self._model_signature_dict is None:
|
996
1099
|
raise exceptions.SnowflakeMLException(
|
@@ -998,35 +1101,3 @@ class OneVsRestClassifier(BaseTransformer):
|
|
998
1101
|
original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
|
999
1102
|
)
|
1000
1103
|
return self._model_signature_dict
|
1001
|
-
|
1002
|
-
def to_sklearn(self) -> Any:
|
1003
|
-
"""Get sklearn.multiclass.OneVsRestClassifier object.
|
1004
|
-
"""
|
1005
|
-
if self._sklearn_object is None:
|
1006
|
-
self._sklearn_object = self._create_sklearn_object()
|
1007
|
-
return self._sklearn_object
|
1008
|
-
|
1009
|
-
def to_xgboost(self) -> Any:
|
1010
|
-
raise exceptions.SnowflakeMLException(
|
1011
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1012
|
-
original_exception=AttributeError(
|
1013
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1014
|
-
"to_xgboost()",
|
1015
|
-
"to_sklearn()"
|
1016
|
-
)
|
1017
|
-
),
|
1018
|
-
)
|
1019
|
-
|
1020
|
-
def to_lightgbm(self) -> Any:
|
1021
|
-
raise exceptions.SnowflakeMLException(
|
1022
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1023
|
-
original_exception=AttributeError(
|
1024
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1025
|
-
"to_lightgbm()",
|
1026
|
-
"to_sklearn()"
|
1027
|
-
)
|
1028
|
-
),
|
1029
|
-
)
|
1030
|
-
|
1031
|
-
def _get_dependencies(self) -> List[str]:
|
1032
|
-
return self._deps
|