snowflake-ml-python 1.4.0__py3-none-any.whl → 1.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +11 -1
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/feature_store/feature_store.py +151 -78
- snowflake/ml/feature_store/feature_view.py +12 -24
- snowflake/ml/fileset/sfcfs.py +56 -50
- snowflake/ml/fileset/stage_fs.py +48 -13
- snowflake/ml/model/_client/model/model_version_impl.py +2 -50
- snowflake/ml/model/_client/ops/model_ops.py +78 -29
- snowflake/ml/model/_client/sql/model.py +23 -2
- snowflake/ml/model/_client/sql/model_version.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +19 -54
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +8 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +36 -6
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -2
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +195 -123
- snowflake/ml/modeling/cluster/affinity_propagation.py +195 -123
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +195 -123
- snowflake/ml/modeling/cluster/birch.py +195 -123
- snowflake/ml/modeling/cluster/bisecting_k_means.py +195 -123
- snowflake/ml/modeling/cluster/dbscan.py +195 -123
- snowflake/ml/modeling/cluster/feature_agglomeration.py +195 -123
- snowflake/ml/modeling/cluster/k_means.py +195 -123
- snowflake/ml/modeling/cluster/mean_shift.py +195 -123
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +195 -123
- snowflake/ml/modeling/cluster/optics.py +195 -123
- snowflake/ml/modeling/cluster/spectral_biclustering.py +195 -123
- snowflake/ml/modeling/cluster/spectral_clustering.py +195 -123
- snowflake/ml/modeling/cluster/spectral_coclustering.py +195 -123
- snowflake/ml/modeling/compose/column_transformer.py +195 -123
- snowflake/ml/modeling/compose/transformed_target_regressor.py +195 -123
- snowflake/ml/modeling/covariance/elliptic_envelope.py +195 -123
- snowflake/ml/modeling/covariance/empirical_covariance.py +195 -123
- snowflake/ml/modeling/covariance/graphical_lasso.py +195 -123
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +195 -123
- snowflake/ml/modeling/covariance/ledoit_wolf.py +195 -123
- snowflake/ml/modeling/covariance/min_cov_det.py +195 -123
- snowflake/ml/modeling/covariance/oas.py +195 -123
- snowflake/ml/modeling/covariance/shrunk_covariance.py +195 -123
- snowflake/ml/modeling/decomposition/dictionary_learning.py +195 -123
- snowflake/ml/modeling/decomposition/factor_analysis.py +195 -123
- snowflake/ml/modeling/decomposition/fast_ica.py +195 -123
- snowflake/ml/modeling/decomposition/incremental_pca.py +195 -123
- snowflake/ml/modeling/decomposition/kernel_pca.py +195 -123
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +195 -123
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +195 -123
- snowflake/ml/modeling/decomposition/pca.py +195 -123
- snowflake/ml/modeling/decomposition/sparse_pca.py +195 -123
- snowflake/ml/modeling/decomposition/truncated_svd.py +195 -123
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +195 -123
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +195 -123
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +195 -123
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +195 -123
- snowflake/ml/modeling/ensemble/bagging_classifier.py +195 -123
- snowflake/ml/modeling/ensemble/bagging_regressor.py +195 -123
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +195 -123
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +195 -123
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +195 -123
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +195 -123
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +195 -123
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +195 -123
- snowflake/ml/modeling/ensemble/isolation_forest.py +195 -123
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +195 -123
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +195 -123
- snowflake/ml/modeling/ensemble/stacking_regressor.py +195 -123
- snowflake/ml/modeling/ensemble/voting_classifier.py +195 -123
- snowflake/ml/modeling/ensemble/voting_regressor.py +195 -123
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +195 -123
- snowflake/ml/modeling/feature_selection/select_fdr.py +195 -123
- snowflake/ml/modeling/feature_selection/select_fpr.py +195 -123
- snowflake/ml/modeling/feature_selection/select_fwe.py +195 -123
- snowflake/ml/modeling/feature_selection/select_k_best.py +195 -123
- snowflake/ml/modeling/feature_selection/select_percentile.py +195 -123
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +195 -123
- snowflake/ml/modeling/feature_selection/variance_threshold.py +195 -123
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +9 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +195 -123
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +195 -123
- snowflake/ml/modeling/impute/iterative_imputer.py +195 -123
- snowflake/ml/modeling/impute/knn_imputer.py +195 -123
- snowflake/ml/modeling/impute/missing_indicator.py +195 -123
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +195 -123
- snowflake/ml/modeling/kernel_approximation/nystroem.py +195 -123
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +195 -123
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +195 -123
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +195 -123
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +195 -123
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +195 -123
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/ard_regression.py +195 -123
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +195 -123
- snowflake/ml/modeling/linear_model/elastic_net.py +195 -123
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +195 -123
- snowflake/ml/modeling/linear_model/gamma_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/huber_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/lars.py +195 -123
- snowflake/ml/modeling/linear_model/lars_cv.py +195 -123
- snowflake/ml/modeling/linear_model/lasso.py +195 -123
- snowflake/ml/modeling/linear_model/lasso_cv.py +195 -123
- snowflake/ml/modeling/linear_model/lasso_lars.py +195 -123
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +195 -123
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +195 -123
- snowflake/ml/modeling/linear_model/linear_regression.py +195 -123
- snowflake/ml/modeling/linear_model/logistic_regression.py +195 -123
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +195 -123
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +195 -123
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +195 -123
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +195 -123
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +195 -123
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +195 -123
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +195 -123
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/perceptron.py +195 -123
- snowflake/ml/modeling/linear_model/poisson_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/ransac_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/ridge.py +195 -123
- snowflake/ml/modeling/linear_model/ridge_classifier.py +195 -123
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +195 -123
- snowflake/ml/modeling/linear_model/ridge_cv.py +195 -123
- snowflake/ml/modeling/linear_model/sgd_classifier.py +195 -123
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +195 -123
- snowflake/ml/modeling/linear_model/sgd_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +195 -123
- snowflake/ml/modeling/manifold/isomap.py +195 -123
- snowflake/ml/modeling/manifold/mds.py +195 -123
- snowflake/ml/modeling/manifold/spectral_embedding.py +195 -123
- snowflake/ml/modeling/manifold/tsne.py +195 -123
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +195 -123
- snowflake/ml/modeling/mixture/gaussian_mixture.py +195 -123
- snowflake/ml/modeling/model_selection/grid_search_cv.py +42 -18
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +42 -18
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +195 -123
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +195 -123
- snowflake/ml/modeling/multiclass/output_code_classifier.py +195 -123
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +195 -123
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +195 -123
- snowflake/ml/modeling/naive_bayes/complement_nb.py +195 -123
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +195 -123
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +195 -123
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +195 -123
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +195 -123
- snowflake/ml/modeling/neighbors/kernel_density.py +195 -123
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +195 -123
- snowflake/ml/modeling/neighbors/nearest_centroid.py +195 -123
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +195 -123
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +195 -123
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +195 -123
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +195 -123
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +195 -123
- snowflake/ml/modeling/neural_network/mlp_classifier.py +195 -123
- snowflake/ml/modeling/neural_network/mlp_regressor.py +195 -123
- snowflake/ml/modeling/pipeline/pipeline.py +4 -4
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +195 -123
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +195 -123
- snowflake/ml/modeling/semi_supervised/label_spreading.py +195 -123
- snowflake/ml/modeling/svm/linear_svc.py +195 -123
- snowflake/ml/modeling/svm/linear_svr.py +195 -123
- snowflake/ml/modeling/svm/nu_svc.py +195 -123
- snowflake/ml/modeling/svm/nu_svr.py +195 -123
- snowflake/ml/modeling/svm/svc.py +195 -123
- snowflake/ml/modeling/svm/svr.py +195 -123
- snowflake/ml/modeling/tree/decision_tree_classifier.py +195 -123
- snowflake/ml/modeling/tree/decision_tree_regressor.py +195 -123
- snowflake/ml/modeling/tree/extra_tree_classifier.py +195 -123
- snowflake/ml/modeling/tree/extra_tree_regressor.py +195 -123
- snowflake/ml/modeling/xgboost/xgb_classifier.py +195 -123
- snowflake/ml/modeling/xgboost/xgb_regressor.py +195 -123
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +195 -123
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +195 -123
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.4.1.dist-info}/METADATA +68 -57
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.4.1.dist-info}/RECORD +202 -200
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.4.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.4.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.4.1.dist-info}/top_level.txt +0 -0
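Notable in the list above: 1.4.1 adds dedicated model packager handlers for CatBoost and LightGBM (`model/_packager/model_handlers/catboost.py` and `lightgbm.py`) and introduces a `model/_packager/model_runtime` layer in place of the old `_model_composer/model_runtime` one. As a rough, hedged illustration only (not taken from this package's documentation), a native LightGBM model would typically be logged through the existing `snowflake.ml.registry.Registry` API that these handlers back; the connection values and model name below are placeholders.

```python
# Illustrative sketch, assuming the 1.4.x Registry API; placeholders must be filled in.
import lightgbm as lgb
import pandas as pd
from sklearn.datasets import make_classification
from snowflake.ml.registry import Registry
from snowflake.snowpark import Session

# Train a small native LightGBM model on synthetic data.
X, y = make_classification(n_samples=200, n_features=4, random_state=0)
X = pd.DataFrame(X, columns=[f"F{i}" for i in range(4)])
clf = lgb.LGBMClassifier().fit(X, y)

# Connect and log the model; the new lightgbm handler packages it.
session = Session.builder.configs({"account": "...", "user": "...", "password": "..."}).create()
reg = Registry(session=session)
mv = reg.log_model(
    clf,
    model_name="LGBM_DEMO",        # placeholder name
    version_name="V1",
    sample_input_data=X.head(10),  # used to infer the model signature
)
print(mv.run(X.head(5), function_name="predict"))
```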
snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py

@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
     BatchInferenceKwargsTypedDict,
     ScoreKwargsTypedDict
 )
+from snowflake.ml.model._signatures import utils as model_signature_utils
+from snowflake.ml.model.model_signature import (
+    BaseFeatureSpec,
+    DataType,
+    FeatureSpec,
+    ModelSignature,
+    _infer_signature,
+    _rename_signature_with_snowflake_identifiers,
+)
 
 from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
 
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
     validate_sklearn_args,
 )
 
-from snowflake.ml.model.model_signature import (
-    DataType,
-    FeatureSpec,
-    ModelSignature,
-    _infer_signature,
-    _rename_signature_with_snowflake_identifiers,
-    BaseFeatureSpec,
-)
-from snowflake.ml.model._signatures import utils as model_signature_utils
-
 _PROJECT = "ModelDevelopment"
 # Derive subproject from module name by removing "sklearn"
 # and converting module name from underscore to CamelCase
@@ -219,12 +218,7 @@ class OneVsRestClassifier(BaseTransformer):
         )
         return selected_cols
 
-    @telemetry.send_api_usage_telemetry(
-        project=_PROJECT,
-        subproject=_SUBPROJECT,
-        custom_tags=dict([("autogen", True)]),
-    )
-    def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "OneVsRestClassifier":
+    def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "OneVsRestClassifier":
         """Fit underlying estimators
         For more details on this function, see [sklearn.multiclass.OneVsRestClassifier.fit]
         (https://scikit-learn.org/stable/modules/generated/sklearn.multiclass.OneVsRestClassifier.html#sklearn.multiclass.OneVsRestClassifier.fit)
@@ -251,12 +245,14 @@ class OneVsRestClassifier(BaseTransformer):
 
         self._snowpark_cols = dataset.select(self.input_cols).columns
 
-
+        # If we are already in a stored procedure, no need to kick off another one.
         if SNOWML_SPROC_ENV in os.environ:
             statement_params = telemetry.get_function_usage_statement_params(
                 project=_PROJECT,
                 subproject=_SUBPROJECT,
-                function_name=telemetry.get_statement_params_full_func_name(
+                function_name=telemetry.get_statement_params_full_func_name(
+                    inspect.currentframe(), OneVsRestClassifier.__class__.__name__
+                ),
                 api_calls=[Session.call],
                 custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
             )
@@ -277,7 +273,7 @@ class OneVsRestClassifier(BaseTransformer):
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
-        self.
+        self._generate_model_signatures(dataset)
         return self
 
     def _batch_inference_validate_snowpark(
@@ -353,7 +349,9 @@ class OneVsRestClassifier(BaseTransformer):
             # when it is classifier, infer the datatype from label columns
             if expected_type_inferred == "" and 'predict' in self.model_signatures:
                 # Batch inference takes a single expected output column type. Use the first columns type for now.
-                label_cols_signatures = [
+                label_cols_signatures = [
+                    row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
+                ]
                 if len(label_cols_signatures) == 0:
                     error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
                     raise exceptions.SnowflakeMLException(
@@ -361,25 +359,22 @@ class OneVsRestClassifier(BaseTransformer):
                         original_exception=ValueError(error_str),
                     )
 
-                expected_type_inferred = convert_sp_to_sf_type(
-                    label_cols_signatures[0].as_snowpark_type()
-                )
+                expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
             self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
-            assert isinstance(
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
-                session
-                dependencies
-                drop_input_cols
-                expected_output_cols_type
+                session=dataset._session,
+                dependencies=self._deps,
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_type=expected_type_inferred,
             )
 
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -419,7 +414,7 @@ class OneVsRestClassifier(BaseTransformer):
             Transformed dataset.
         """
         super()._check_dataset_type(dataset)
-        inference_method="transform"
+        inference_method = "transform"
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
@@ -456,17 +451,14 @@ class OneVsRestClassifier(BaseTransformer):
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
-                session
-                dependencies
-                drop_input_cols
-                expected_output_cols_type
+                session=dataset._session,
+                dependencies=self._deps,
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_type=expected_dtype,
             )
 
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -485,7 +477,11 @@ class OneVsRestClassifier(BaseTransformer):
         return output_df
 
     @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(
+    def fit_predict(
+        self,
+        dataset: Union[DataFrame, pd.DataFrame],
+        output_cols_prefix: str = "fit_predict_",
+    ) -> Union[DataFrame, pd.DataFrame]:
         """ Method not supported for this class.
 
 
@@ -510,7 +506,9 @@ class OneVsRestClassifier(BaseTransformer):
         )
         output_result, fitted_estimator = model_trainer.train_fit_predict(
             drop_input_cols=self._drop_input_cols,
-            expected_output_cols_list=
+            expected_output_cols_list=(
+                self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
+            ),
         )
         self._sklearn_object = fitted_estimator
         self._is_fitted = True
@@ -527,6 +525,62 @@ class OneVsRestClassifier(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object.embedding_
 
+
+    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
+        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
+        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
+        """
+        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
+        # The following condition is introduced for kneighbors methods, and not used in other methods
+        if output_cols:
+            output_cols = [
+                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
+                for c in output_cols
+            ]
+        elif getattr(self._sklearn_object, "classes_", None) is None:
+            output_cols = [output_cols_prefix]
+        elif self._sklearn_object is not None:
+            classes = self._sklearn_object.classes_
+            if isinstance(classes, numpy.ndarray):
+                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
+            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
+                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
+                output_cols = []
+                for i, cl in enumerate(classes):
+                    # For binary classification, there is only one output column for each class
+                    # ndarray as the two classes are complementary.
+                    if len(cl) == 2:
+                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
+                    else:
+                        output_cols.extend([
+                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
+                        ])
+        else:
+            output_cols = []
+
+        # Make sure column names are valid snowflake identifiers.
+        assert output_cols is not None  # Make MyPy happy
+        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
+
+        return rv
+
+    def _align_expected_output_names(
+        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
+    ) -> List[str]:
+        # in case the inferred output column names dimension is different
+        # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
+        output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
+        output_df_columns = list(output_df_pd.columns)
+        output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
+        if self.sample_weight_col:
+            output_df_columns_set -= set(self.sample_weight_col)
+        # if the dimension of inferred output column names is correct; use it
+        if len(expected_output_cols_list) == len(output_df_columns_set):
+            return expected_output_cols_list
+        # otherwise, use the sklearn estimator's output
+        else:
+            return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
@@ -559,24 +613,28 @@ class OneVsRestClassifier(BaseTransformer):
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
+
         if isinstance(dataset, DataFrame):
             self._deps = self._batch_inference_validate_snowpark(
                 dataset=dataset,
                 inference_method=inference_method,
             )
-            assert isinstance(
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                drop_input_cols
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
 
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -588,7 +646,7 @@ class OneVsRestClassifier(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -620,7 +678,8 @@ class OneVsRestClassifier(BaseTransformer):
             Output dataset with log probability of the sample for each class in the model.
         """
         super()._check_dataset_type(dataset)
-        inference_method="predict_log_proba"
+        inference_method = "predict_log_proba"
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
@@ -631,18 +690,20 @@ class OneVsRestClassifier(BaseTransformer):
                 dataset=dataset,
                 inference_method=inference_method,
             )
-            assert isinstance(
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                drop_input_cols
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -655,7 +716,7 @@ class OneVsRestClassifier(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -683,30 +744,34 @@ class OneVsRestClassifier(BaseTransformer):
             Output dataset with results of the decision function for the samples in input dataset.
         """
         super()._check_dataset_type(dataset)
-        inference_method="decision_function"
+        inference_method = "decision_function"
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
+
         if isinstance(dataset, DataFrame):
             self._deps = self._batch_inference_validate_snowpark(
                 dataset=dataset,
                 inference_method=inference_method,
             )
-            assert isinstance(
+            assert isinstance(
+                dataset._session, Session
+            )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                drop_input_cols
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
 
         elif isinstance(dataset, pd.DataFrame):
-            transform_kwargs = dict(
-                snowpark_input_cols = self._snowpark_cols,
-                drop_input_cols = self._drop_input_cols
-            )
+            transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
 
         transform_handlers = ModelTransformerBuilder.build(
             dataset=dataset,
@@ -719,7 +784,7 @@ class OneVsRestClassifier(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -748,12 +813,14 @@ class OneVsRestClassifier(BaseTransformer):
             Output dataset with probability of the sample for each class in the model.
         """
         super()._check_dataset_type(dataset)
-        inference_method="score_samples"
+        inference_method = "score_samples"
 
         # This dictionary contains optional kwargs for batch inference. These kwargs
         # are specific to the type of dataset used.
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
+        expected_output_cols = self._get_output_column_names(output_cols_prefix)
+
         if isinstance(dataset, DataFrame):
             self._deps = self._batch_inference_validate_snowpark(
                 dataset=dataset,
@@ -766,6 +833,9 @@ class OneVsRestClassifier(BaseTransformer):
                 drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
+            expected_output_cols = self._align_expected_output_names(
+                inference_method, dataset, expected_output_cols, output_cols_prefix
+            )
 
         elif isinstance(dataset, pd.DataFrame):
             transform_kwargs = dict(
@@ -784,7 +854,7 @@ class OneVsRestClassifier(BaseTransformer):
         output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
             inference_method=inference_method,
             input_cols=self.input_cols,
-            expected_output_cols=
+            expected_output_cols=expected_output_cols,
             **transform_kwargs
         )
         return output_df
@@ -931,50 +1001,84 @@ class OneVsRestClassifier(BaseTransformer):
         )
         return output_df
 
+
+
+    def to_sklearn(self) -> Any:
+        """Get sklearn.multiclass.OneVsRestClassifier object.
+        """
+        if self._sklearn_object is None:
+            self._sklearn_object = self._create_sklearn_object()
+        return self._sklearn_object
+
+    def to_xgboost(self) -> Any:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.METHOD_NOT_ALLOWED,
+            original_exception=AttributeError(
+                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
+                    "to_xgboost()",
+                    "to_sklearn()"
+                )
+            ),
+        )
+
+    def to_lightgbm(self) -> Any:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.METHOD_NOT_ALLOWED,
+            original_exception=AttributeError(
+                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
+                    "to_lightgbm()",
+                    "to_sklearn()"
+                )
+            ),
+        )
 
-    def
+    def _get_dependencies(self) -> List[str]:
+        return self._deps
+
+
+    def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
         self._model_signature_dict = dict()
 
         PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
 
-        inputs = list(_infer_signature(dataset[self.input_cols], "input"))
+        inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
         outputs: List[BaseFeatureSpec] = []
         if hasattr(self, "predict"):
             # keep mypy happy
-            assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
+            assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
             # For classifier, the type of predict is the same as the type of label
-            if self._sklearn_object._estimator_type ==
-
+            if self._sklearn_object._estimator_type == "classifier":
+                # label columns is the desired type for output
                 outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
                 # rename the output columns
                 outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
-                self._model_signature_dict["predict"] = ModelSignature(
-
-
+                self._model_signature_dict["predict"] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
             # For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
             # For outlier models, returns -1 for outliers and 1 for inliers.
-            # Clusterer returns int64 cluster labels.
+            # Clusterer returns int64 cluster labels.
             elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
                 outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
-                self._model_signature_dict["predict"] = ModelSignature(
-
-
-
+                self._model_signature_dict["predict"] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
+
             # For regressor, the type of predict is float64
-            elif self._sklearn_object._estimator_type ==
+            elif self._sklearn_object._estimator_type == "regressor":
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
-                self._model_signature_dict["predict"] = ModelSignature(
-
-
-
+                self._model_signature_dict["predict"] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
+
         for prob_func in PROB_FUNCTIONS:
             if hasattr(self, prob_func):
                 output_cols_prefix: str = f"{prob_func}_"
                 output_column_names = self._get_output_column_names(output_cols_prefix)
                 outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
-                self._model_signature_dict[prob_func] = ModelSignature(
-
-
+                self._model_signature_dict[prob_func] = ModelSignature(
+                    inputs, ([] if self._drop_input_cols else inputs) + outputs
+                )
 
         # Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
         items = list(self._model_signature_dict.items())
@@ -987,10 +1091,10 @@ class OneVsRestClassifier(BaseTransformer):
         """Returns model signature of current class.
 
         Raises:
-
+            SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
 
         Returns:
-            Dict
+            Dict with each method and its input output signature
         """
         if self._model_signature_dict is None:
             raise exceptions.SnowflakeMLException(
@@ -998,35 +1102,3 @@ class OneVsRestClassifier(BaseTransformer):
                 original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
             )
         return self._model_signature_dict
-
-    def to_sklearn(self) -> Any:
-        """Get sklearn.multiclass.OneVsRestClassifier object.
-        """
-        if self._sklearn_object is None:
-            self._sklearn_object = self._create_sklearn_object()
-        return self._sklearn_object
-
-    def to_xgboost(self) -> Any:
-        raise exceptions.SnowflakeMLException(
-            error_code=error_codes.METHOD_NOT_ALLOWED,
-            original_exception=AttributeError(
-                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
-                    "to_xgboost()",
-                    "to_sklearn()"
-                )
-            ),
-        )
-
-    def to_lightgbm(self) -> Any:
-        raise exceptions.SnowflakeMLException(
-            error_code=error_codes.METHOD_NOT_ALLOWED,
-            original_exception=AttributeError(
-                modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
-                    "to_lightgbm()",
-                    "to_sklearn()"
-                )
-            ),
-        )
-
-    def _get_dependencies(self) -> List[str]:
-        return self._deps