snowflake-ml-python 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/file_utils.py +3 -3
- snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
- snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
- snowflake/ml/_internal/telemetry.py +11 -2
- snowflake/ml/_internal/utils/formatting.py +1 -1
- snowflake/ml/feature_store/feature_store.py +15 -106
- snowflake/ml/fileset/sfcfs.py +4 -3
- snowflake/ml/fileset/stage_fs.py +18 -0
- snowflake/ml/model/_api.py +9 -9
- snowflake/ml/model/_client/model/model_version_impl.py +20 -15
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +3 -9
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -5
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +7 -6
- snowflake/ml/model/_model_composer/model_composer.py +10 -8
- snowflake/ml/model/_model_composer/model_method/function_generator.py +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +2 -2
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +1 -1
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +7 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +2 -2
- snowflake/ml/model/_packager/model_handlers/llm.py +1 -1
- snowflake/ml/model/_packager/model_handlers/mlflow.py +1 -1
- snowflake/ml/model/_packager/model_handlers/pytorch.py +13 -10
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +214 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +6 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +15 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +8 -8
- snowflake/ml/model/_packager/model_handlers/torchscript.py +7 -7
- snowflake/ml/model/_packager/model_handlers/xgboost.py +8 -8
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_packager.py +8 -6
- snowflake/ml/model/custom_model.py +3 -1
- snowflake/ml/model/type_hints.py +13 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +61 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -43
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +4 -4
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +21 -17
- snowflake/ml/modeling/_internal/model_specifications.py +3 -1
- snowflake/ml/modeling/_internal/model_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +547 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +67 -114
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -9
- snowflake/ml/modeling/_internal/transformer_protocols.py +2 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +33 -61
- snowflake/ml/modeling/cluster/affinity_propagation.py +33 -61
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +33 -61
- snowflake/ml/modeling/cluster/birch.py +33 -61
- snowflake/ml/modeling/cluster/bisecting_k_means.py +33 -61
- snowflake/ml/modeling/cluster/dbscan.py +33 -61
- snowflake/ml/modeling/cluster/feature_agglomeration.py +33 -61
- snowflake/ml/modeling/cluster/k_means.py +33 -61
- snowflake/ml/modeling/cluster/mean_shift.py +33 -61
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +33 -61
- snowflake/ml/modeling/cluster/optics.py +33 -61
- snowflake/ml/modeling/cluster/spectral_biclustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_clustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_coclustering.py +33 -61
- snowflake/ml/modeling/compose/column_transformer.py +33 -61
- snowflake/ml/modeling/compose/transformed_target_regressor.py +33 -61
- snowflake/ml/modeling/covariance/elliptic_envelope.py +33 -61
- snowflake/ml/modeling/covariance/empirical_covariance.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +33 -61
- snowflake/ml/modeling/covariance/ledoit_wolf.py +33 -61
- snowflake/ml/modeling/covariance/min_cov_det.py +33 -61
- snowflake/ml/modeling/covariance/oas.py +33 -61
- snowflake/ml/modeling/covariance/shrunk_covariance.py +33 -61
- snowflake/ml/modeling/decomposition/dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/factor_analysis.py +33 -61
- snowflake/ml/modeling/decomposition/fast_ica.py +33 -61
- snowflake/ml/modeling/decomposition/incremental_pca.py +33 -61
- snowflake/ml/modeling/decomposition/kernel_pca.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/pca.py +33 -61
- snowflake/ml/modeling/decomposition/sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/truncated_svd.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/isolation_forest.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/stacking_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/voting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/voting_regressor.py +33 -61
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fdr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fpr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fwe.py +33 -61
- snowflake/ml/modeling/feature_selection/select_k_best.py +33 -61
- snowflake/ml/modeling/feature_selection/select_percentile.py +33 -61
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +33 -61
- snowflake/ml/modeling/feature_selection/variance_threshold.py +33 -61
- snowflake/ml/modeling/framework/base.py +55 -5
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +33 -61
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +33 -61
- snowflake/ml/modeling/impute/iterative_imputer.py +33 -61
- snowflake/ml/modeling/impute/knn_imputer.py +33 -61
- snowflake/ml/modeling/impute/missing_indicator.py +33 -61
- snowflake/ml/modeling/impute/simple_imputer.py +4 -15
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/nystroem.py +33 -61
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +33 -61
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +33 -61
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +36 -63
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +36 -63
- snowflake/ml/modeling/linear_model/ard_regression.py +33 -61
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/gamma_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/huber_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/lars.py +33 -61
- snowflake/ml/modeling/linear_model/lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +33 -61
- snowflake/ml/modeling/linear_model/linear_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/perceptron.py +33 -61
- snowflake/ml/modeling/linear_model/poisson_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ransac_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ridge.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_cv.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +33 -61
- snowflake/ml/modeling/manifold/isomap.py +33 -61
- snowflake/ml/modeling/manifold/mds.py +33 -61
- snowflake/ml/modeling/manifold/spectral_embedding.py +33 -61
- snowflake/ml/modeling/manifold/tsne.py +33 -61
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +33 -61
- snowflake/ml/modeling/mixture/gaussian_mixture.py +33 -61
- snowflake/ml/modeling/model_selection/grid_search_cv.py +39 -57
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +26 -57
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/output_code_classifier.py +33 -61
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/complement_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neighbors/kernel_density.py +33 -61
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_centroid.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +33 -61
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_classifier.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_regressor.py +33 -61
- snowflake/ml/modeling/preprocessing/polynomial_features.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_propagation.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_spreading.py +33 -61
- snowflake/ml/modeling/svm/linear_svc.py +33 -61
- snowflake/ml/modeling/svm/linear_svr.py +33 -61
- snowflake/ml/modeling/svm/nu_svc.py +33 -61
- snowflake/ml/modeling/svm/nu_svr.py +33 -61
- snowflake/ml/modeling/svm/svc.py +33 -61
- snowflake/ml/modeling/svm/svr.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_regressor.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +33 -61
- snowflake/ml/registry/_manager/model_manager.py +6 -2
- snowflake/ml/registry/model_registry.py +100 -27
- snowflake/ml/registry/registry.py +6 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/METADATA +43 -7
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/RECORD +211 -206
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/top_level.txt +0 -0
@@ -224,6 +224,7 @@ class RandomizedSearchCV(BaseTransformer):
|
|
224
224
|
expensive and is not strictly required to select the parameters that
|
225
225
|
yield the best generalization performance.
|
226
226
|
"""
|
227
|
+
|
227
228
|
_ENABLE_DISTRIBUTED = True
|
228
229
|
|
229
230
|
def __init__( # type: ignore[no-untyped-def]
|
@@ -345,13 +346,7 @@ class RandomizedSearchCV(BaseTransformer):
|
|
345
346
|
self._get_model_signatures(dataset)
|
346
347
|
return self
|
347
348
|
|
348
|
-
def
|
349
|
-
if self._drop_input_cols:
|
350
|
-
return []
|
351
|
-
else:
|
352
|
-
return list(set(dataset.columns) - set(self.output_cols))
|
353
|
-
|
354
|
-
def _batch_inference_validate_snowpark(self, dataset: DataFrame, inference_method: str) -> None:
|
349
|
+
def _batch_inference_validate_snowpark(self, dataset: DataFrame, inference_method: str) -> List[str]:
|
355
350
|
"""Util method to run validate that batch inference can be run on a snowpark dataframe."""
|
356
351
|
if not self._is_fitted:
|
357
352
|
raise exceptions.SnowflakeMLException(
|
@@ -368,7 +363,7 @@ class RandomizedSearchCV(BaseTransformer):
|
|
368
363
|
original_exception=ValueError("Session must not specified for snowpark dataset."),
|
369
364
|
)
|
370
365
|
# Validate that key package version in user workspace are supported in snowflake conda channel
|
371
|
-
pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
366
|
+
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
372
367
|
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT
|
373
368
|
)
|
374
369
|
|
@@ -403,7 +398,7 @@ class RandomizedSearchCV(BaseTransformer):
|
|
403
398
|
expected_type_inferred = convert_sp_to_sf_type(
|
404
399
|
self.model_signatures["predict"].outputs[0].as_snowpark_type()
|
405
400
|
)
|
406
|
-
self._batch_inference_validate_snowpark(
|
401
|
+
self._deps = self._batch_inference_validate_snowpark(
|
407
402
|
dataset=dataset,
|
408
403
|
inference_method=inference_method,
|
409
404
|
)
|
@@ -412,8 +407,8 @@ class RandomizedSearchCV(BaseTransformer):
|
|
412
407
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
413
408
|
transform_kwargs = dict(
|
414
409
|
session=dataset._session,
|
415
|
-
dependencies=self.
|
416
|
-
|
410
|
+
dependencies=self._deps,
|
411
|
+
drop_input_cols=self._drop_input_cols,
|
417
412
|
expected_output_cols_type=expected_type_inferred,
|
418
413
|
)
|
419
414
|
|
@@ -462,14 +457,14 @@ class RandomizedSearchCV(BaseTransformer):
|
|
462
457
|
inference_method = "transform"
|
463
458
|
|
464
459
|
if isinstance(dataset, DataFrame):
|
465
|
-
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
460
|
+
self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
466
461
|
assert isinstance(
|
467
462
|
dataset._session, Session
|
468
463
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
469
464
|
transform_kwargs = dict(
|
470
465
|
session=dataset._session,
|
471
|
-
dependencies=self.
|
472
|
-
|
466
|
+
dependencies=self._deps,
|
467
|
+
drop_input_cols=self._drop_input_cols,
|
473
468
|
)
|
474
469
|
|
475
470
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -491,36 +486,6 @@ class RandomizedSearchCV(BaseTransformer):
|
|
491
486
|
)
|
492
487
|
return output_df
|
493
488
|
|
494
|
-
def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
|
495
|
-
"""Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
496
|
-
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
497
|
-
|
498
|
-
Args:
|
499
|
-
output_cols_prefix (str): prefix according to the function
|
500
|
-
|
501
|
-
Returns:
|
502
|
-
List[str]: output cols with prefix
|
503
|
-
"""
|
504
|
-
if getattr(self._sklearn_object, "classes_", None) is None:
|
505
|
-
return [output_cols_prefix]
|
506
|
-
|
507
|
-
assert self._sklearn_object is not None # keep mypy happy
|
508
|
-
classes = self._sklearn_object.classes_
|
509
|
-
if isinstance(classes, np.ndarray):
|
510
|
-
return [f"{output_cols_prefix}{c}" for c in classes.tolist()]
|
511
|
-
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], np.ndarray):
|
512
|
-
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
513
|
-
output_cols = []
|
514
|
-
for i, cl in enumerate(classes):
|
515
|
-
# For binary classification, there is only one output column for each class
|
516
|
-
# ndarray as the two classes are complementary.
|
517
|
-
if len(cl) == 2:
|
518
|
-
output_cols.append(f"{output_cols_prefix}_{i}_{cl[0]}")
|
519
|
-
else:
|
520
|
-
output_cols.extend([f"{output_cols_prefix}_{i}_{c}" for c in cl.tolist()])
|
521
|
-
return output_cols
|
522
|
-
return []
|
523
|
-
|
524
489
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
525
490
|
@telemetry.send_api_usage_telemetry(
|
526
491
|
project=_PROJECT,
|
@@ -550,14 +515,14 @@ class RandomizedSearchCV(BaseTransformer):
|
|
550
515
|
inference_method = "predict_proba"
|
551
516
|
|
552
517
|
if isinstance(dataset, DataFrame):
|
553
|
-
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
518
|
+
self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
554
519
|
assert isinstance(
|
555
520
|
dataset._session, Session
|
556
521
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
557
522
|
transform_kwargs = dict(
|
558
523
|
session=dataset._session,
|
559
|
-
dependencies=self.
|
560
|
-
|
524
|
+
dependencies=self._deps,
|
525
|
+
drop_input_cols=self._drop_input_cols,
|
561
526
|
expected_output_cols_type="float",
|
562
527
|
)
|
563
528
|
|
@@ -610,14 +575,14 @@ class RandomizedSearchCV(BaseTransformer):
|
|
610
575
|
inference_method = "predict_log_proba"
|
611
576
|
|
612
577
|
if isinstance(dataset, DataFrame):
|
613
|
-
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
578
|
+
self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
614
579
|
assert isinstance(
|
615
580
|
dataset._session, Session
|
616
581
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
617
582
|
transform_kwargs = dict(
|
618
583
|
session=dataset._session,
|
619
|
-
dependencies=self.
|
620
|
-
|
584
|
+
dependencies=self._deps,
|
585
|
+
drop_input_cols=self._drop_input_cols,
|
621
586
|
expected_output_cols_type="float",
|
622
587
|
)
|
623
588
|
|
@@ -669,14 +634,14 @@ class RandomizedSearchCV(BaseTransformer):
|
|
669
634
|
inference_method = "decision_function"
|
670
635
|
|
671
636
|
if isinstance(dataset, DataFrame):
|
672
|
-
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
637
|
+
self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
673
638
|
assert isinstance(
|
674
639
|
dataset._session, Session
|
675
640
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
676
641
|
transform_kwargs = dict(
|
677
642
|
session=dataset._session,
|
678
|
-
dependencies=self.
|
679
|
-
|
643
|
+
dependencies=self._deps,
|
644
|
+
drop_input_cols=self._drop_input_cols,
|
680
645
|
expected_output_cols_type="float",
|
681
646
|
)
|
682
647
|
|
@@ -730,14 +695,14 @@ class RandomizedSearchCV(BaseTransformer):
|
|
730
695
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
731
696
|
|
732
697
|
if isinstance(dataset, DataFrame):
|
733
|
-
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
698
|
+
self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
734
699
|
assert isinstance(
|
735
700
|
dataset._session, Session
|
736
701
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
737
702
|
transform_kwargs = dict(
|
738
703
|
session=dataset._session,
|
739
|
-
dependencies=self.
|
740
|
-
|
704
|
+
dependencies=self._deps,
|
705
|
+
drop_input_cols=self._drop_input_cols,
|
741
706
|
expected_output_cols_type="float",
|
742
707
|
)
|
743
708
|
|
@@ -780,6 +745,10 @@ class RandomizedSearchCV(BaseTransformer):
|
|
780
745
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
781
746
|
|
782
747
|
if isinstance(dataset, DataFrame):
|
748
|
+
self._deps = self._batch_inference_validate_snowpark(
|
749
|
+
dataset=dataset,
|
750
|
+
inference_method="score",
|
751
|
+
)
|
783
752
|
selected_cols = self._get_active_columns()
|
784
753
|
if len(selected_cols) > 0:
|
785
754
|
dataset = dataset.select(selected_cols)
|
@@ -787,7 +756,7 @@ class RandomizedSearchCV(BaseTransformer):
|
|
787
756
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
788
757
|
transform_kwargs = dict(
|
789
758
|
session=dataset._session,
|
790
|
-
dependencies=["snowflake-snowpark-python"] + self.
|
759
|
+
dependencies=["snowflake-snowpark-python"] + self._deps,
|
791
760
|
score_sproc_imports=["sklearn"],
|
792
761
|
)
|
793
762
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -271,18 +271,24 @@ class OneVsOneClassifier(BaseTransformer):
|
|
271
271
|
self._get_model_signatures(dataset)
|
272
272
|
return self
|
273
273
|
|
274
|
-
def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
|
275
|
-
if self._drop_input_cols:
|
276
|
-
return []
|
277
|
-
else:
|
278
|
-
return list(set(dataset.columns) - set(self.output_cols))
|
279
|
-
|
280
274
|
def _batch_inference_validate_snowpark(
|
281
275
|
self,
|
282
276
|
dataset: DataFrame,
|
283
277
|
inference_method: str,
|
284
278
|
) -> List[str]:
|
285
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
279
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe and
|
280
|
+
return the available package that exists in the snowflake anaconda channel
|
281
|
+
|
282
|
+
Args:
|
283
|
+
dataset: snowpark dataframe
|
284
|
+
inference_method: the inference method such as predict, score...
|
285
|
+
|
286
|
+
Raises:
|
287
|
+
SnowflakeMLException: If the estimator is not fitted, raise error
|
288
|
+
SnowflakeMLException: If the session is None, raise error
|
289
|
+
|
290
|
+
Returns:
|
291
|
+
A list of available package that exists in the snowflake anaconda channel
|
286
292
|
"""
|
287
293
|
if not self._is_fitted:
|
288
294
|
raise exceptions.SnowflakeMLException(
|
@@ -356,7 +362,7 @@ class OneVsOneClassifier(BaseTransformer):
|
|
356
362
|
transform_kwargs = dict(
|
357
363
|
session = dataset._session,
|
358
364
|
dependencies = self._deps,
|
359
|
-
|
365
|
+
drop_input_cols = self._drop_input_cols,
|
360
366
|
expected_output_cols_type = expected_type_inferred,
|
361
367
|
)
|
362
368
|
|
@@ -416,16 +422,16 @@ class OneVsOneClassifier(BaseTransformer):
|
|
416
422
|
# from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
|
417
423
|
# based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
|
418
424
|
# each row containing a list of values.
|
419
|
-
expected_dtype = "
|
425
|
+
expected_dtype = "array"
|
420
426
|
|
421
427
|
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
422
428
|
if expected_dtype == "":
|
423
|
-
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
|
429
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
|
424
430
|
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
425
|
-
expected_dtype = "
|
426
|
-
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
|
431
|
+
expected_dtype = "array"
|
432
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
|
427
433
|
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
428
|
-
expected_dtype = "
|
434
|
+
expected_dtype = "array"
|
429
435
|
else:
|
430
436
|
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
431
437
|
# We can only infer the output types from the input types if the following two statemetns are true:
|
@@ -443,7 +449,7 @@ class OneVsOneClassifier(BaseTransformer):
|
|
443
449
|
transform_kwargs = dict(
|
444
450
|
session = dataset._session,
|
445
451
|
dependencies = self._deps,
|
446
|
-
|
452
|
+
drop_input_cols = self._drop_input_cols,
|
447
453
|
expected_output_cols_type = expected_dtype,
|
448
454
|
)
|
449
455
|
|
@@ -494,7 +500,7 @@ class OneVsOneClassifier(BaseTransformer):
|
|
494
500
|
subproject=_SUBPROJECT,
|
495
501
|
)
|
496
502
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
497
|
-
|
503
|
+
drop_input_cols=self._drop_input_cols,
|
498
504
|
expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
|
499
505
|
)
|
500
506
|
self._sklearn_object = fitted_estimator
|
@@ -512,44 +518,6 @@ class OneVsOneClassifier(BaseTransformer):
|
|
512
518
|
assert self._sklearn_object is not None
|
513
519
|
return self._sklearn_object.embedding_
|
514
520
|
|
515
|
-
|
516
|
-
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
517
|
-
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
518
|
-
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
519
|
-
"""
|
520
|
-
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
521
|
-
if output_cols:
|
522
|
-
output_cols = [
|
523
|
-
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
524
|
-
for c in output_cols
|
525
|
-
]
|
526
|
-
elif getattr(self._sklearn_object, "classes_", None) is None:
|
527
|
-
output_cols = [output_cols_prefix]
|
528
|
-
elif self._sklearn_object is not None:
|
529
|
-
classes = self._sklearn_object.classes_
|
530
|
-
if isinstance(classes, numpy.ndarray):
|
531
|
-
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
532
|
-
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
533
|
-
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
534
|
-
output_cols = []
|
535
|
-
for i, cl in enumerate(classes):
|
536
|
-
# For binary classification, there is only one output column for each class
|
537
|
-
# ndarray as the two classes are complementary.
|
538
|
-
if len(cl) == 2:
|
539
|
-
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
540
|
-
else:
|
541
|
-
output_cols.extend([
|
542
|
-
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
543
|
-
])
|
544
|
-
else:
|
545
|
-
output_cols = []
|
546
|
-
|
547
|
-
# Make sure column names are valid snowflake identifiers.
|
548
|
-
assert output_cols is not None # Make MyPy happy
|
549
|
-
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
550
|
-
|
551
|
-
return rv
|
552
|
-
|
553
521
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
554
522
|
@telemetry.send_api_usage_telemetry(
|
555
523
|
project=_PROJECT,
|
@@ -589,7 +557,7 @@ class OneVsOneClassifier(BaseTransformer):
|
|
589
557
|
transform_kwargs = dict(
|
590
558
|
session=dataset._session,
|
591
559
|
dependencies=self._deps,
|
592
|
-
|
560
|
+
drop_input_cols = self._drop_input_cols,
|
593
561
|
expected_output_cols_type="float",
|
594
562
|
)
|
595
563
|
|
@@ -654,7 +622,7 @@ class OneVsOneClassifier(BaseTransformer):
|
|
654
622
|
transform_kwargs = dict(
|
655
623
|
session=dataset._session,
|
656
624
|
dependencies=self._deps,
|
657
|
-
|
625
|
+
drop_input_cols = self._drop_input_cols,
|
658
626
|
expected_output_cols_type="float",
|
659
627
|
)
|
660
628
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -717,7 +685,7 @@ class OneVsOneClassifier(BaseTransformer):
|
|
717
685
|
transform_kwargs = dict(
|
718
686
|
session=dataset._session,
|
719
687
|
dependencies=self._deps,
|
720
|
-
|
688
|
+
drop_input_cols = self._drop_input_cols,
|
721
689
|
expected_output_cols_type="float",
|
722
690
|
)
|
723
691
|
|
@@ -782,7 +750,7 @@ class OneVsOneClassifier(BaseTransformer):
|
|
782
750
|
transform_kwargs = dict(
|
783
751
|
session=dataset._session,
|
784
752
|
dependencies=self._deps,
|
785
|
-
|
753
|
+
drop_input_cols = self._drop_input_cols,
|
786
754
|
expected_output_cols_type="float",
|
787
755
|
)
|
788
756
|
|
@@ -838,13 +806,17 @@ class OneVsOneClassifier(BaseTransformer):
|
|
838
806
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
839
807
|
|
840
808
|
if isinstance(dataset, DataFrame):
|
809
|
+
self._deps = self._batch_inference_validate_snowpark(
|
810
|
+
dataset=dataset,
|
811
|
+
inference_method="score",
|
812
|
+
)
|
841
813
|
selected_cols = self._get_active_columns()
|
842
814
|
if len(selected_cols) > 0:
|
843
815
|
dataset = dataset.select(selected_cols)
|
844
816
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
845
817
|
transform_kwargs = dict(
|
846
818
|
session=dataset._session,
|
847
|
-
dependencies=["snowflake-snowpark-python"] + self.
|
819
|
+
dependencies=["snowflake-snowpark-python"] + self._deps,
|
848
820
|
score_sproc_imports=['sklearn'],
|
849
821
|
)
|
850
822
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -918,9 +890,9 @@ class OneVsOneClassifier(BaseTransformer):
|
|
918
890
|
transform_kwargs = dict(
|
919
891
|
session = dataset._session,
|
920
892
|
dependencies = self._deps,
|
921
|
-
|
922
|
-
expected_output_cols_type
|
923
|
-
n_neighbors =
|
893
|
+
drop_input_cols = self._drop_input_cols,
|
894
|
+
expected_output_cols_type="array",
|
895
|
+
n_neighbors = n_neighbors,
|
924
896
|
return_distance = return_distance
|
925
897
|
)
|
926
898
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -280,18 +280,24 @@ class OneVsRestClassifier(BaseTransformer):
|
|
280
280
|
self._get_model_signatures(dataset)
|
281
281
|
return self
|
282
282
|
|
283
|
-
def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
|
284
|
-
if self._drop_input_cols:
|
285
|
-
return []
|
286
|
-
else:
|
287
|
-
return list(set(dataset.columns) - set(self.output_cols))
|
288
|
-
|
289
283
|
def _batch_inference_validate_snowpark(
|
290
284
|
self,
|
291
285
|
dataset: DataFrame,
|
292
286
|
inference_method: str,
|
293
287
|
) -> List[str]:
|
294
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
288
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe and
|
289
|
+
return the available package that exists in the snowflake anaconda channel
|
290
|
+
|
291
|
+
Args:
|
292
|
+
dataset: snowpark dataframe
|
293
|
+
inference_method: the inference method such as predict, score...
|
294
|
+
|
295
|
+
Raises:
|
296
|
+
SnowflakeMLException: If the estimator is not fitted, raise error
|
297
|
+
SnowflakeMLException: If the session is None, raise error
|
298
|
+
|
299
|
+
Returns:
|
300
|
+
A list of available package that exists in the snowflake anaconda channel
|
295
301
|
"""
|
296
302
|
if not self._is_fitted:
|
297
303
|
raise exceptions.SnowflakeMLException(
|
@@ -365,7 +371,7 @@ class OneVsRestClassifier(BaseTransformer):
|
|
365
371
|
transform_kwargs = dict(
|
366
372
|
session = dataset._session,
|
367
373
|
dependencies = self._deps,
|
368
|
-
|
374
|
+
drop_input_cols = self._drop_input_cols,
|
369
375
|
expected_output_cols_type = expected_type_inferred,
|
370
376
|
)
|
371
377
|
|
@@ -425,16 +431,16 @@ class OneVsRestClassifier(BaseTransformer):
|
|
425
431
|
# from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
|
426
432
|
# based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
|
427
433
|
# each row containing a list of values.
|
428
|
-
expected_dtype = "
|
434
|
+
expected_dtype = "array"
|
429
435
|
|
430
436
|
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
431
437
|
if expected_dtype == "":
|
432
|
-
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
|
438
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
|
433
439
|
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
434
|
-
expected_dtype = "
|
435
|
-
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
|
440
|
+
expected_dtype = "array"
|
441
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
|
436
442
|
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
437
|
-
expected_dtype = "
|
443
|
+
expected_dtype = "array"
|
438
444
|
else:
|
439
445
|
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
440
446
|
# We can only infer the output types from the input types if the following two statemetns are true:
|
@@ -452,7 +458,7 @@ class OneVsRestClassifier(BaseTransformer):
|
|
452
458
|
transform_kwargs = dict(
|
453
459
|
session = dataset._session,
|
454
460
|
dependencies = self._deps,
|
455
|
-
|
461
|
+
drop_input_cols = self._drop_input_cols,
|
456
462
|
expected_output_cols_type = expected_dtype,
|
457
463
|
)
|
458
464
|
|
@@ -503,7 +509,7 @@ class OneVsRestClassifier(BaseTransformer):
|
|
503
509
|
subproject=_SUBPROJECT,
|
504
510
|
)
|
505
511
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
506
|
-
|
512
|
+
drop_input_cols=self._drop_input_cols,
|
507
513
|
expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
|
508
514
|
)
|
509
515
|
self._sklearn_object = fitted_estimator
|
@@ -521,44 +527,6 @@ class OneVsRestClassifier(BaseTransformer):
|
|
521
527
|
assert self._sklearn_object is not None
|
522
528
|
return self._sklearn_object.embedding_
|
523
529
|
|
524
|
-
|
525
|
-
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
526
|
-
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
527
|
-
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
528
|
-
"""
|
529
|
-
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
530
|
-
if output_cols:
|
531
|
-
output_cols = [
|
532
|
-
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
533
|
-
for c in output_cols
|
534
|
-
]
|
535
|
-
elif getattr(self._sklearn_object, "classes_", None) is None:
|
536
|
-
output_cols = [output_cols_prefix]
|
537
|
-
elif self._sklearn_object is not None:
|
538
|
-
classes = self._sklearn_object.classes_
|
539
|
-
if isinstance(classes, numpy.ndarray):
|
540
|
-
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
541
|
-
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
542
|
-
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
543
|
-
output_cols = []
|
544
|
-
for i, cl in enumerate(classes):
|
545
|
-
# For binary classification, there is only one output column for each class
|
546
|
-
# ndarray as the two classes are complementary.
|
547
|
-
if len(cl) == 2:
|
548
|
-
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
549
|
-
else:
|
550
|
-
output_cols.extend([
|
551
|
-
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
552
|
-
])
|
553
|
-
else:
|
554
|
-
output_cols = []
|
555
|
-
|
556
|
-
# Make sure column names are valid snowflake identifiers.
|
557
|
-
assert output_cols is not None # Make MyPy happy
|
558
|
-
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
559
|
-
|
560
|
-
return rv
|
561
|
-
|
562
530
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
563
531
|
@telemetry.send_api_usage_telemetry(
|
564
532
|
project=_PROJECT,
|
@@ -600,7 +568,7 @@ class OneVsRestClassifier(BaseTransformer):
|
|
600
568
|
transform_kwargs = dict(
|
601
569
|
session=dataset._session,
|
602
570
|
dependencies=self._deps,
|
603
|
-
|
571
|
+
drop_input_cols = self._drop_input_cols,
|
604
572
|
expected_output_cols_type="float",
|
605
573
|
)
|
606
574
|
|
@@ -667,7 +635,7 @@ class OneVsRestClassifier(BaseTransformer):
|
|
667
635
|
transform_kwargs = dict(
|
668
636
|
session=dataset._session,
|
669
637
|
dependencies=self._deps,
|
670
|
-
|
638
|
+
drop_input_cols = self._drop_input_cols,
|
671
639
|
expected_output_cols_type="float",
|
672
640
|
)
|
673
641
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -730,7 +698,7 @@ class OneVsRestClassifier(BaseTransformer):
|
|
730
698
|
transform_kwargs = dict(
|
731
699
|
session=dataset._session,
|
732
700
|
dependencies=self._deps,
|
733
|
-
|
701
|
+
drop_input_cols = self._drop_input_cols,
|
734
702
|
expected_output_cols_type="float",
|
735
703
|
)
|
736
704
|
|
@@ -795,7 +763,7 @@ class OneVsRestClassifier(BaseTransformer):
|
|
795
763
|
transform_kwargs = dict(
|
796
764
|
session=dataset._session,
|
797
765
|
dependencies=self._deps,
|
798
|
-
|
766
|
+
drop_input_cols = self._drop_input_cols,
|
799
767
|
expected_output_cols_type="float",
|
800
768
|
)
|
801
769
|
|
@@ -851,13 +819,17 @@ class OneVsRestClassifier(BaseTransformer):
|
|
851
819
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
852
820
|
|
853
821
|
if isinstance(dataset, DataFrame):
|
822
|
+
self._deps = self._batch_inference_validate_snowpark(
|
823
|
+
dataset=dataset,
|
824
|
+
inference_method="score",
|
825
|
+
)
|
854
826
|
selected_cols = self._get_active_columns()
|
855
827
|
if len(selected_cols) > 0:
|
856
828
|
dataset = dataset.select(selected_cols)
|
857
829
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
858
830
|
transform_kwargs = dict(
|
859
831
|
session=dataset._session,
|
860
|
-
dependencies=["snowflake-snowpark-python"] + self.
|
832
|
+
dependencies=["snowflake-snowpark-python"] + self._deps,
|
861
833
|
score_sproc_imports=['sklearn'],
|
862
834
|
)
|
863
835
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -931,9 +903,9 @@ class OneVsRestClassifier(BaseTransformer):
|
|
931
903
|
transform_kwargs = dict(
|
932
904
|
session = dataset._session,
|
933
905
|
dependencies = self._deps,
|
934
|
-
|
935
|
-
expected_output_cols_type
|
936
|
-
n_neighbors =
|
906
|
+
drop_input_cols = self._drop_input_cols,
|
907
|
+
expected_output_cols_type="array",
|
908
|
+
n_neighbors = n_neighbors,
|
937
909
|
return_distance = return_distance
|
938
910
|
)
|
939
911
|
elif isinstance(dataset, pd.DataFrame):
|