snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
|
|
33
33
|
BatchInferenceKwargsTypedDict,
|
34
34
|
ScoreKwargsTypedDict
|
35
35
|
)
|
36
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
37
|
+
from snowflake.ml.model.model_signature import (
|
38
|
+
BaseFeatureSpec,
|
39
|
+
DataType,
|
40
|
+
FeatureSpec,
|
41
|
+
ModelSignature,
|
42
|
+
_infer_signature,
|
43
|
+
_rename_signature_with_snowflake_identifiers,
|
44
|
+
)
|
36
45
|
|
37
46
|
from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
|
38
47
|
|
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
|
|
43
52
|
validate_sklearn_args,
|
44
53
|
)
|
45
54
|
|
46
|
-
from snowflake.ml.model.model_signature import (
|
47
|
-
DataType,
|
48
|
-
FeatureSpec,
|
49
|
-
ModelSignature,
|
50
|
-
_infer_signature,
|
51
|
-
_rename_signature_with_snowflake_identifiers,
|
52
|
-
BaseFeatureSpec,
|
53
|
-
)
|
54
|
-
from snowflake.ml.model._signatures import utils as model_signature_utils
|
55
|
-
|
56
55
|
_PROJECT = "ModelDevelopment"
|
57
56
|
# Derive subproject from module name by removing "sklearn"
|
58
57
|
# and converting module name from underscore to CamelCase
|
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.multiclass".replace("skl
|
|
61
60
|
|
62
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
63
62
|
|
64
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
65
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
66
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
67
|
-
return check
|
68
|
-
|
69
|
-
|
70
63
|
class OneVsOneClassifier(BaseTransformer):
|
71
64
|
r"""One-vs-one multiclass strategy
|
72
65
|
For more details on this class, see [sklearn.multiclass.OneVsOneClassifier]
|
@@ -210,12 +203,7 @@ class OneVsOneClassifier(BaseTransformer):
|
|
210
203
|
)
|
211
204
|
return selected_cols
|
212
205
|
|
213
|
-
|
214
|
-
project=_PROJECT,
|
215
|
-
subproject=_SUBPROJECT,
|
216
|
-
custom_tags=dict([("autogen", True)]),
|
217
|
-
)
|
218
|
-
def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "OneVsOneClassifier":
|
206
|
+
def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "OneVsOneClassifier":
|
219
207
|
"""Fit underlying estimators
|
220
208
|
For more details on this function, see [sklearn.multiclass.OneVsOneClassifier.fit]
|
221
209
|
(https://scikit-learn.org/stable/modules/generated/sklearn.multiclass.OneVsOneClassifier.html#sklearn.multiclass.OneVsOneClassifier.fit)
|
@@ -242,12 +230,14 @@ class OneVsOneClassifier(BaseTransformer):
|
|
242
230
|
|
243
231
|
self._snowpark_cols = dataset.select(self.input_cols).columns
|
244
232
|
|
245
|
-
|
233
|
+
# If we are already in a stored procedure, no need to kick off another one.
|
246
234
|
if SNOWML_SPROC_ENV in os.environ:
|
247
235
|
statement_params = telemetry.get_function_usage_statement_params(
|
248
236
|
project=_PROJECT,
|
249
237
|
subproject=_SUBPROJECT,
|
250
|
-
function_name=telemetry.get_statement_params_full_func_name(
|
238
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
239
|
+
inspect.currentframe(), OneVsOneClassifier.__class__.__name__
|
240
|
+
),
|
251
241
|
api_calls=[Session.call],
|
252
242
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
253
243
|
)
|
@@ -268,27 +258,24 @@ class OneVsOneClassifier(BaseTransformer):
|
|
268
258
|
)
|
269
259
|
self._sklearn_object = model_trainer.train()
|
270
260
|
self._is_fitted = True
|
271
|
-
self.
|
261
|
+
self._generate_model_signatures(dataset)
|
272
262
|
return self
|
273
263
|
|
274
264
|
def _batch_inference_validate_snowpark(
|
275
265
|
self,
|
276
266
|
dataset: DataFrame,
|
277
267
|
inference_method: str,
|
278
|
-
) ->
|
279
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
280
|
-
return the available package that exists in the snowflake anaconda channel
|
268
|
+
) -> None:
|
269
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
281
270
|
|
282
271
|
Args:
|
283
272
|
dataset: snowpark dataframe
|
284
273
|
inference_method: the inference method such as predict, score...
|
285
|
-
|
274
|
+
|
286
275
|
Raises:
|
287
276
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
288
277
|
SnowflakeMLException: If the session is None, raise error
|
289
278
|
|
290
|
-
Returns:
|
291
|
-
A list of available package that exists in the snowflake anaconda channel
|
292
279
|
"""
|
293
280
|
if not self._is_fitted:
|
294
281
|
raise exceptions.SnowflakeMLException(
|
@@ -306,9 +293,7 @@ class OneVsOneClassifier(BaseTransformer):
|
|
306
293
|
"Session must not specified for snowpark dataset."
|
307
294
|
),
|
308
295
|
)
|
309
|
-
|
310
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
311
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
296
|
+
|
312
297
|
|
313
298
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
314
299
|
@telemetry.send_api_usage_telemetry(
|
@@ -344,7 +329,9 @@ class OneVsOneClassifier(BaseTransformer):
|
|
344
329
|
# when it is classifier, infer the datatype from label columns
|
345
330
|
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
346
331
|
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
347
|
-
label_cols_signatures = [
|
332
|
+
label_cols_signatures = [
|
333
|
+
row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
|
334
|
+
]
|
348
335
|
if len(label_cols_signatures) == 0:
|
349
336
|
error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
|
350
337
|
raise exceptions.SnowflakeMLException(
|
@@ -352,25 +339,23 @@ class OneVsOneClassifier(BaseTransformer):
|
|
352
339
|
original_exception=ValueError(error_str),
|
353
340
|
)
|
354
341
|
|
355
|
-
expected_type_inferred = convert_sp_to_sf_type(
|
356
|
-
label_cols_signatures[0].as_snowpark_type()
|
357
|
-
)
|
342
|
+
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
358
343
|
|
359
|
-
self.
|
360
|
-
|
344
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
345
|
+
self._deps = self._get_dependencies()
|
346
|
+
assert isinstance(
|
347
|
+
dataset._session, Session
|
348
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
361
349
|
|
362
350
|
transform_kwargs = dict(
|
363
|
-
session
|
364
|
-
dependencies
|
365
|
-
drop_input_cols
|
366
|
-
expected_output_cols_type
|
351
|
+
session=dataset._session,
|
352
|
+
dependencies=self._deps,
|
353
|
+
drop_input_cols=self._drop_input_cols,
|
354
|
+
expected_output_cols_type=expected_type_inferred,
|
367
355
|
)
|
368
356
|
|
369
357
|
elif isinstance(dataset, pd.DataFrame):
|
370
|
-
transform_kwargs = dict(
|
371
|
-
snowpark_input_cols = self._snowpark_cols,
|
372
|
-
drop_input_cols = self._drop_input_cols
|
373
|
-
)
|
358
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
374
359
|
|
375
360
|
transform_handlers = ModelTransformerBuilder.build(
|
376
361
|
dataset=dataset,
|
@@ -410,7 +395,7 @@ class OneVsOneClassifier(BaseTransformer):
|
|
410
395
|
Transformed dataset.
|
411
396
|
"""
|
412
397
|
super()._check_dataset_type(dataset)
|
413
|
-
inference_method="transform"
|
398
|
+
inference_method = "transform"
|
414
399
|
|
415
400
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
416
401
|
# are specific to the type of dataset used.
|
@@ -440,24 +425,19 @@ class OneVsOneClassifier(BaseTransformer):
|
|
440
425
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
441
426
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
442
427
|
|
443
|
-
self.
|
444
|
-
|
445
|
-
inference_method=inference_method,
|
446
|
-
)
|
428
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
429
|
+
self._deps = self._get_dependencies()
|
447
430
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
448
431
|
|
449
432
|
transform_kwargs = dict(
|
450
|
-
session
|
451
|
-
dependencies
|
452
|
-
drop_input_cols
|
453
|
-
expected_output_cols_type
|
433
|
+
session=dataset._session,
|
434
|
+
dependencies=self._deps,
|
435
|
+
drop_input_cols=self._drop_input_cols,
|
436
|
+
expected_output_cols_type=expected_dtype,
|
454
437
|
)
|
455
438
|
|
456
439
|
elif isinstance(dataset, pd.DataFrame):
|
457
|
-
transform_kwargs = dict(
|
458
|
-
snowpark_input_cols = self._snowpark_cols,
|
459
|
-
drop_input_cols = self._drop_input_cols
|
460
|
-
)
|
440
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
461
441
|
|
462
442
|
transform_handlers = ModelTransformerBuilder.build(
|
463
443
|
dataset=dataset,
|
@@ -476,7 +456,11 @@ class OneVsOneClassifier(BaseTransformer):
|
|
476
456
|
return output_df
|
477
457
|
|
478
458
|
@available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
|
479
|
-
def fit_predict(
|
459
|
+
def fit_predict(
|
460
|
+
self,
|
461
|
+
dataset: Union[DataFrame, pd.DataFrame],
|
462
|
+
output_cols_prefix: str = "fit_predict_",
|
463
|
+
) -> Union[DataFrame, pd.DataFrame]:
|
480
464
|
""" Method not supported for this class.
|
481
465
|
|
482
466
|
|
@@ -501,22 +485,104 @@ class OneVsOneClassifier(BaseTransformer):
|
|
501
485
|
)
|
502
486
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
503
487
|
drop_input_cols=self._drop_input_cols,
|
504
|
-
expected_output_cols_list=
|
488
|
+
expected_output_cols_list=(
|
489
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
490
|
+
),
|
505
491
|
)
|
506
492
|
self._sklearn_object = fitted_estimator
|
507
493
|
self._is_fitted = True
|
508
494
|
return output_result
|
509
495
|
|
496
|
+
|
497
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
498
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
499
|
+
""" Method not supported for this class.
|
500
|
+
|
510
501
|
|
511
|
-
|
512
|
-
|
513
|
-
|
502
|
+
Raises:
|
503
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
504
|
+
|
505
|
+
Args:
|
506
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
507
|
+
Snowpark or Pandas DataFrame.
|
508
|
+
output_cols_prefix: Prefix for the response columns
|
514
509
|
Returns:
|
515
510
|
Transformed dataset.
|
516
511
|
"""
|
517
|
-
self.
|
518
|
-
|
519
|
-
|
512
|
+
self._infer_input_output_cols(dataset)
|
513
|
+
super()._check_dataset_type(dataset)
|
514
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
515
|
+
estimator=self._sklearn_object,
|
516
|
+
dataset=dataset,
|
517
|
+
input_cols=self.input_cols,
|
518
|
+
label_cols=self.label_cols,
|
519
|
+
sample_weight_col=self.sample_weight_col,
|
520
|
+
autogenerated=self._autogenerated,
|
521
|
+
subproject=_SUBPROJECT,
|
522
|
+
)
|
523
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
524
|
+
drop_input_cols=self._drop_input_cols,
|
525
|
+
expected_output_cols_list=self.output_cols,
|
526
|
+
)
|
527
|
+
self._sklearn_object = fitted_estimator
|
528
|
+
self._is_fitted = True
|
529
|
+
return output_result
|
530
|
+
|
531
|
+
|
532
|
+
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
533
|
+
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
534
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
535
|
+
"""
|
536
|
+
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
537
|
+
# The following condition is introduced for kneighbors methods, and not used in other methods
|
538
|
+
if output_cols:
|
539
|
+
output_cols = [
|
540
|
+
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
541
|
+
for c in output_cols
|
542
|
+
]
|
543
|
+
elif getattr(self._sklearn_object, "classes_", None) is None:
|
544
|
+
output_cols = [output_cols_prefix]
|
545
|
+
elif self._sklearn_object is not None:
|
546
|
+
classes = self._sklearn_object.classes_
|
547
|
+
if isinstance(classes, numpy.ndarray):
|
548
|
+
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
549
|
+
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
550
|
+
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
551
|
+
output_cols = []
|
552
|
+
for i, cl in enumerate(classes):
|
553
|
+
# For binary classification, there is only one output column for each class
|
554
|
+
# ndarray as the two classes are complementary.
|
555
|
+
if len(cl) == 2:
|
556
|
+
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
557
|
+
else:
|
558
|
+
output_cols.extend([
|
559
|
+
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
560
|
+
])
|
561
|
+
else:
|
562
|
+
output_cols = []
|
563
|
+
|
564
|
+
# Make sure column names are valid snowflake identifiers.
|
565
|
+
assert output_cols is not None # Make MyPy happy
|
566
|
+
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
567
|
+
|
568
|
+
return rv
|
569
|
+
|
570
|
+
def _align_expected_output_names(
|
571
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
572
|
+
) -> List[str]:
|
573
|
+
# in case the inferred output column names dimension is different
|
574
|
+
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
575
|
+
output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
|
576
|
+
output_df_columns = list(output_df_pd.columns)
|
577
|
+
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
578
|
+
if self.sample_weight_col:
|
579
|
+
output_df_columns_set -= set(self.sample_weight_col)
|
580
|
+
# if the dimension of inferred output column names is correct; use it
|
581
|
+
if len(expected_output_cols_list) == len(output_df_columns_set):
|
582
|
+
return expected_output_cols_list
|
583
|
+
# otherwise, use the sklearn estimator's output
|
584
|
+
else:
|
585
|
+
return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
520
586
|
|
521
587
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
522
588
|
@telemetry.send_api_usage_telemetry(
|
@@ -548,24 +614,26 @@ class OneVsOneClassifier(BaseTransformer):
|
|
548
614
|
# are specific to the type of dataset used.
|
549
615
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
550
616
|
|
617
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
618
|
+
|
551
619
|
if isinstance(dataset, DataFrame):
|
552
|
-
self.
|
553
|
-
|
554
|
-
|
555
|
-
|
556
|
-
|
620
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
621
|
+
self._deps = self._get_dependencies()
|
622
|
+
assert isinstance(
|
623
|
+
dataset._session, Session
|
624
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
557
625
|
transform_kwargs = dict(
|
558
626
|
session=dataset._session,
|
559
627
|
dependencies=self._deps,
|
560
|
-
drop_input_cols
|
628
|
+
drop_input_cols=self._drop_input_cols,
|
561
629
|
expected_output_cols_type="float",
|
562
630
|
)
|
631
|
+
expected_output_cols = self._align_expected_output_names(
|
632
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
633
|
+
)
|
563
634
|
|
564
635
|
elif isinstance(dataset, pd.DataFrame):
|
565
|
-
transform_kwargs = dict(
|
566
|
-
snowpark_input_cols = self._snowpark_cols,
|
567
|
-
drop_input_cols = self._drop_input_cols
|
568
|
-
)
|
636
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
569
637
|
|
570
638
|
transform_handlers = ModelTransformerBuilder.build(
|
571
639
|
dataset=dataset,
|
@@ -577,7 +645,7 @@ class OneVsOneClassifier(BaseTransformer):
|
|
577
645
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
578
646
|
inference_method=inference_method,
|
579
647
|
input_cols=self.input_cols,
|
580
|
-
expected_output_cols=
|
648
|
+
expected_output_cols=expected_output_cols,
|
581
649
|
**transform_kwargs
|
582
650
|
)
|
583
651
|
return output_df
|
@@ -607,29 +675,30 @@ class OneVsOneClassifier(BaseTransformer):
|
|
607
675
|
Output dataset with log probability of the sample for each class in the model.
|
608
676
|
"""
|
609
677
|
super()._check_dataset_type(dataset)
|
610
|
-
inference_method="predict_log_proba"
|
678
|
+
inference_method = "predict_log_proba"
|
679
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
611
680
|
|
612
681
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
613
682
|
# are specific to the type of dataset used.
|
614
683
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
615
684
|
|
616
685
|
if isinstance(dataset, DataFrame):
|
617
|
-
self.
|
618
|
-
|
619
|
-
|
620
|
-
|
621
|
-
|
686
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
687
|
+
self._deps = self._get_dependencies()
|
688
|
+
assert isinstance(
|
689
|
+
dataset._session, Session
|
690
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
622
691
|
transform_kwargs = dict(
|
623
692
|
session=dataset._session,
|
624
693
|
dependencies=self._deps,
|
625
|
-
drop_input_cols
|
694
|
+
drop_input_cols=self._drop_input_cols,
|
626
695
|
expected_output_cols_type="float",
|
627
696
|
)
|
697
|
+
expected_output_cols = self._align_expected_output_names(
|
698
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
699
|
+
)
|
628
700
|
elif isinstance(dataset, pd.DataFrame):
|
629
|
-
transform_kwargs = dict(
|
630
|
-
snowpark_input_cols = self._snowpark_cols,
|
631
|
-
drop_input_cols = self._drop_input_cols
|
632
|
-
)
|
701
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
633
702
|
|
634
703
|
transform_handlers = ModelTransformerBuilder.build(
|
635
704
|
dataset=dataset,
|
@@ -642,7 +711,7 @@ class OneVsOneClassifier(BaseTransformer):
|
|
642
711
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
643
712
|
inference_method=inference_method,
|
644
713
|
input_cols=self.input_cols,
|
645
|
-
expected_output_cols=
|
714
|
+
expected_output_cols=expected_output_cols,
|
646
715
|
**transform_kwargs
|
647
716
|
)
|
648
717
|
return output_df
|
@@ -670,30 +739,32 @@ class OneVsOneClassifier(BaseTransformer):
|
|
670
739
|
Output dataset with results of the decision function for the samples in input dataset.
|
671
740
|
"""
|
672
741
|
super()._check_dataset_type(dataset)
|
673
|
-
inference_method="decision_function"
|
742
|
+
inference_method = "decision_function"
|
674
743
|
|
675
744
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
676
745
|
# are specific to the type of dataset used.
|
677
746
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
678
747
|
|
748
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
749
|
+
|
679
750
|
if isinstance(dataset, DataFrame):
|
680
|
-
self.
|
681
|
-
|
682
|
-
|
683
|
-
|
684
|
-
|
751
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
752
|
+
self._deps = self._get_dependencies()
|
753
|
+
assert isinstance(
|
754
|
+
dataset._session, Session
|
755
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
685
756
|
transform_kwargs = dict(
|
686
757
|
session=dataset._session,
|
687
758
|
dependencies=self._deps,
|
688
|
-
drop_input_cols
|
759
|
+
drop_input_cols=self._drop_input_cols,
|
689
760
|
expected_output_cols_type="float",
|
690
761
|
)
|
762
|
+
expected_output_cols = self._align_expected_output_names(
|
763
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
764
|
+
)
|
691
765
|
|
692
766
|
elif isinstance(dataset, pd.DataFrame):
|
693
|
-
transform_kwargs = dict(
|
694
|
-
snowpark_input_cols = self._snowpark_cols,
|
695
|
-
drop_input_cols = self._drop_input_cols
|
696
|
-
)
|
767
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
697
768
|
|
698
769
|
transform_handlers = ModelTransformerBuilder.build(
|
699
770
|
dataset=dataset,
|
@@ -706,7 +777,7 @@ class OneVsOneClassifier(BaseTransformer):
|
|
706
777
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
707
778
|
inference_method=inference_method,
|
708
779
|
input_cols=self.input_cols,
|
709
|
-
expected_output_cols=
|
780
|
+
expected_output_cols=expected_output_cols,
|
710
781
|
**transform_kwargs
|
711
782
|
)
|
712
783
|
return output_df
|
@@ -735,17 +806,17 @@ class OneVsOneClassifier(BaseTransformer):
|
|
735
806
|
Output dataset with probability of the sample for each class in the model.
|
736
807
|
"""
|
737
808
|
super()._check_dataset_type(dataset)
|
738
|
-
inference_method="score_samples"
|
809
|
+
inference_method = "score_samples"
|
739
810
|
|
740
811
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
741
812
|
# are specific to the type of dataset used.
|
742
813
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
743
814
|
|
815
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
816
|
+
|
744
817
|
if isinstance(dataset, DataFrame):
|
745
|
-
self.
|
746
|
-
|
747
|
-
inference_method=inference_method,
|
748
|
-
)
|
818
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
819
|
+
self._deps = self._get_dependencies()
|
749
820
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
750
821
|
transform_kwargs = dict(
|
751
822
|
session=dataset._session,
|
@@ -753,6 +824,9 @@ class OneVsOneClassifier(BaseTransformer):
|
|
753
824
|
drop_input_cols = self._drop_input_cols,
|
754
825
|
expected_output_cols_type="float",
|
755
826
|
)
|
827
|
+
expected_output_cols = self._align_expected_output_names(
|
828
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
829
|
+
)
|
756
830
|
|
757
831
|
elif isinstance(dataset, pd.DataFrame):
|
758
832
|
transform_kwargs = dict(
|
@@ -771,7 +845,7 @@ class OneVsOneClassifier(BaseTransformer):
|
|
771
845
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
772
846
|
inference_method=inference_method,
|
773
847
|
input_cols=self.input_cols,
|
774
|
-
expected_output_cols=
|
848
|
+
expected_output_cols=expected_output_cols,
|
775
849
|
**transform_kwargs
|
776
850
|
)
|
777
851
|
return output_df
|
@@ -806,17 +880,15 @@ class OneVsOneClassifier(BaseTransformer):
|
|
806
880
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
807
881
|
|
808
882
|
if isinstance(dataset, DataFrame):
|
809
|
-
self.
|
810
|
-
|
811
|
-
inference_method="score",
|
812
|
-
)
|
883
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
884
|
+
self._deps = self._get_dependencies()
|
813
885
|
selected_cols = self._get_active_columns()
|
814
886
|
if len(selected_cols) > 0:
|
815
887
|
dataset = dataset.select(selected_cols)
|
816
888
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
817
889
|
transform_kwargs = dict(
|
818
890
|
session=dataset._session,
|
819
|
-
dependencies=
|
891
|
+
dependencies=self._deps,
|
820
892
|
score_sproc_imports=['sklearn'],
|
821
893
|
)
|
822
894
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -881,11 +953,8 @@ class OneVsOneClassifier(BaseTransformer):
|
|
881
953
|
|
882
954
|
if isinstance(dataset, DataFrame):
|
883
955
|
|
884
|
-
self.
|
885
|
-
|
886
|
-
inference_method=inference_method,
|
887
|
-
|
888
|
-
)
|
956
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
957
|
+
self._deps = self._get_dependencies()
|
889
958
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
890
959
|
transform_kwargs = dict(
|
891
960
|
session = dataset._session,
|
@@ -918,50 +987,84 @@ class OneVsOneClassifier(BaseTransformer):
|
|
918
987
|
)
|
919
988
|
return output_df
|
920
989
|
|
990
|
+
|
991
|
+
|
992
|
+
def to_sklearn(self) -> Any:
|
993
|
+
"""Get sklearn.multiclass.OneVsOneClassifier object.
|
994
|
+
"""
|
995
|
+
if self._sklearn_object is None:
|
996
|
+
self._sklearn_object = self._create_sklearn_object()
|
997
|
+
return self._sklearn_object
|
998
|
+
|
999
|
+
def to_xgboost(self) -> Any:
|
1000
|
+
raise exceptions.SnowflakeMLException(
|
1001
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1002
|
+
original_exception=AttributeError(
|
1003
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1004
|
+
"to_xgboost()",
|
1005
|
+
"to_sklearn()"
|
1006
|
+
)
|
1007
|
+
),
|
1008
|
+
)
|
1009
|
+
|
1010
|
+
def to_lightgbm(self) -> Any:
|
1011
|
+
raise exceptions.SnowflakeMLException(
|
1012
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1013
|
+
original_exception=AttributeError(
|
1014
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1015
|
+
"to_lightgbm()",
|
1016
|
+
"to_sklearn()"
|
1017
|
+
)
|
1018
|
+
),
|
1019
|
+
)
|
1020
|
+
|
1021
|
+
def _get_dependencies(self) -> List[str]:
|
1022
|
+
return self._deps
|
1023
|
+
|
921
1024
|
|
922
|
-
def
|
1025
|
+
def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
923
1026
|
self._model_signature_dict = dict()
|
924
1027
|
|
925
1028
|
PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
|
926
1029
|
|
927
|
-
inputs = list(_infer_signature(dataset[self.input_cols], "input"))
|
1030
|
+
inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
|
928
1031
|
outputs: List[BaseFeatureSpec] = []
|
929
1032
|
if hasattr(self, "predict"):
|
930
1033
|
# keep mypy happy
|
931
|
-
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1034
|
+
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
932
1035
|
# For classifier, the type of predict is the same as the type of label
|
933
|
-
if self._sklearn_object._estimator_type ==
|
934
|
-
|
1036
|
+
if self._sklearn_object._estimator_type == "classifier":
|
1037
|
+
# label columns is the desired type for output
|
935
1038
|
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
936
1039
|
# rename the output columns
|
937
1040
|
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
938
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
939
|
-
|
940
|
-
|
1041
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1042
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1043
|
+
)
|
941
1044
|
# For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
|
942
1045
|
# For outlier models, returns -1 for outliers and 1 for inliers.
|
943
|
-
# Clusterer returns int64 cluster labels.
|
1046
|
+
# Clusterer returns int64 cluster labels.
|
944
1047
|
elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
|
945
1048
|
outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
|
946
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
947
|
-
|
948
|
-
|
949
|
-
|
1049
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1050
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1051
|
+
)
|
1052
|
+
|
950
1053
|
# For regressor, the type of predict is float64
|
951
|
-
elif self._sklearn_object._estimator_type ==
|
1054
|
+
elif self._sklearn_object._estimator_type == "regressor":
|
952
1055
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
953
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
954
|
-
|
955
|
-
|
956
|
-
|
1056
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1057
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1058
|
+
)
|
1059
|
+
|
957
1060
|
for prob_func in PROB_FUNCTIONS:
|
958
1061
|
if hasattr(self, prob_func):
|
959
1062
|
output_cols_prefix: str = f"{prob_func}_"
|
960
1063
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
961
1064
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
962
|
-
self._model_signature_dict[prob_func] = ModelSignature(
|
963
|
-
|
964
|
-
|
1065
|
+
self._model_signature_dict[prob_func] = ModelSignature(
|
1066
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1067
|
+
)
|
965
1068
|
|
966
1069
|
# Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
|
967
1070
|
items = list(self._model_signature_dict.items())
|
@@ -974,10 +1077,10 @@ class OneVsOneClassifier(BaseTransformer):
|
|
974
1077
|
"""Returns model signature of current class.
|
975
1078
|
|
976
1079
|
Raises:
|
977
|
-
|
1080
|
+
SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
|
978
1081
|
|
979
1082
|
Returns:
|
980
|
-
Dict
|
1083
|
+
Dict with each method and its input output signature
|
981
1084
|
"""
|
982
1085
|
if self._model_signature_dict is None:
|
983
1086
|
raise exceptions.SnowflakeMLException(
|
@@ -985,35 +1088,3 @@ class OneVsOneClassifier(BaseTransformer):
|
|
985
1088
|
original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
|
986
1089
|
)
|
987
1090
|
return self._model_signature_dict
|
988
|
-
|
989
|
-
def to_sklearn(self) -> Any:
|
990
|
-
"""Get sklearn.multiclass.OneVsOneClassifier object.
|
991
|
-
"""
|
992
|
-
if self._sklearn_object is None:
|
993
|
-
self._sklearn_object = self._create_sklearn_object()
|
994
|
-
return self._sklearn_object
|
995
|
-
|
996
|
-
def to_xgboost(self) -> Any:
|
997
|
-
raise exceptions.SnowflakeMLException(
|
998
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
999
|
-
original_exception=AttributeError(
|
1000
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1001
|
-
"to_xgboost()",
|
1002
|
-
"to_sklearn()"
|
1003
|
-
)
|
1004
|
-
),
|
1005
|
-
)
|
1006
|
-
|
1007
|
-
def to_lightgbm(self) -> Any:
|
1008
|
-
raise exceptions.SnowflakeMLException(
|
1009
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1010
|
-
original_exception=AttributeError(
|
1011
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1012
|
-
"to_lightgbm()",
|
1013
|
-
"to_sklearn()"
|
1014
|
-
)
|
1015
|
-
),
|
1016
|
-
)
|
1017
|
-
|
1018
|
-
def _get_dependencies(self) -> List[str]:
|
1019
|
-
return self._deps
|