snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +72 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
- snowflake/ml/_internal/telemetry.py +1 -0
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/sql_identifier.py +14 -1
- snowflake/ml/dataset/__init__.py +11 -0
- snowflake/ml/dataset/dataset.py +455 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +199 -0
- snowflake/ml/feature_store/__init__.py +6 -0
- snowflake/ml/feature_store/access_manager.py +279 -0
- snowflake/ml/feature_store/feature_store.py +544 -358
- snowflake/ml/feature_store/feature_view.py +55 -16
- snowflake/ml/fileset/embedded_stage_fs.py +149 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +160 -0
- snowflake/ml/fileset/stage_fs.py +25 -10
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +65 -31
- snowflake/ml/model/_client/model/model_version_impl.py +159 -2
- snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
- snowflake/ml/model/_client/ops/model_ops.py +268 -83
- snowflake/ml/model/_client/sql/_base.py +34 -0
- snowflake/ml/model/_client/sql/model.py +42 -47
- snowflake/ml/model/_client/sql/model_version.py +164 -39
- snowflake/ml/model/_client/sql/stage.py +6 -32
- snowflake/ml/model/_client/sql/tag.py +32 -56
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +64 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +538 -36
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/_manager/model_manager.py +36 -7
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -334,9 +334,12 @@ class GridSearchCV(BaseTransformer):
|
|
334
334
|
self._generate_model_signatures(dataset)
|
335
335
|
return self
|
336
336
|
|
337
|
-
def _batch_inference_validate_snowpark(
|
338
|
-
|
339
|
-
|
337
|
+
def _batch_inference_validate_snowpark(
|
338
|
+
self,
|
339
|
+
dataset: DataFrame,
|
340
|
+
inference_method: str,
|
341
|
+
) -> None:
|
342
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
340
343
|
|
341
344
|
Args:
|
342
345
|
dataset: snowpark dataframe
|
@@ -346,8 +349,6 @@ class GridSearchCV(BaseTransformer):
|
|
346
349
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
347
350
|
SnowflakeMLException: If the session is None, raise error
|
348
351
|
|
349
|
-
Returns:
|
350
|
-
A list of available package that exists in the snowflake anaconda channel
|
351
352
|
"""
|
352
353
|
if not self._is_fitted:
|
353
354
|
raise exceptions.SnowflakeMLException(
|
@@ -363,10 +364,6 @@ class GridSearchCV(BaseTransformer):
|
|
363
364
|
error_code=error_codes.NOT_FOUND,
|
364
365
|
original_exception=ValueError("Session must not specified for snowpark dataset."),
|
365
366
|
)
|
366
|
-
# Validate that key package version in user workspace are supported in snowflake conda channel
|
367
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
368
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT
|
369
|
-
)
|
370
367
|
|
371
368
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
372
369
|
@telemetry.send_api_usage_telemetry(
|
@@ -415,10 +412,8 @@ class GridSearchCV(BaseTransformer):
|
|
415
412
|
)
|
416
413
|
|
417
414
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
418
|
-
self.
|
419
|
-
|
420
|
-
inference_method=inference_method,
|
421
|
-
)
|
415
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
416
|
+
self._deps = self._get_dependencies()
|
422
417
|
|
423
418
|
assert isinstance(
|
424
419
|
dataset._session, Session
|
@@ -476,7 +471,8 @@ class GridSearchCV(BaseTransformer):
|
|
476
471
|
inference_method = "transform"
|
477
472
|
|
478
473
|
if isinstance(dataset, DataFrame):
|
479
|
-
self.
|
474
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
475
|
+
self._deps = self._get_dependencies()
|
480
476
|
assert isinstance(
|
481
477
|
dataset._session, Session
|
482
478
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -535,7 +531,8 @@ class GridSearchCV(BaseTransformer):
|
|
535
531
|
inference_method = "predict_proba"
|
536
532
|
|
537
533
|
if isinstance(dataset, DataFrame):
|
538
|
-
self.
|
534
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
535
|
+
self._deps = self._get_dependencies()
|
539
536
|
assert isinstance(
|
540
537
|
dataset._session, Session
|
541
538
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -595,7 +592,8 @@ class GridSearchCV(BaseTransformer):
|
|
595
592
|
inference_method = "predict_log_proba"
|
596
593
|
|
597
594
|
if isinstance(dataset, DataFrame):
|
598
|
-
self.
|
595
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
596
|
+
self._deps = self._get_dependencies()
|
599
597
|
assert isinstance(
|
600
598
|
dataset._session, Session
|
601
599
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -655,7 +653,8 @@ class GridSearchCV(BaseTransformer):
|
|
655
653
|
inference_method = "decision_function"
|
656
654
|
|
657
655
|
if isinstance(dataset, DataFrame):
|
658
|
-
self.
|
656
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
657
|
+
self._deps = self._get_dependencies()
|
659
658
|
assert isinstance(
|
660
659
|
dataset._session, Session
|
661
660
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -716,7 +715,8 @@ class GridSearchCV(BaseTransformer):
|
|
716
715
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
717
716
|
|
718
717
|
if isinstance(dataset, DataFrame):
|
719
|
-
self.
|
718
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
719
|
+
self._deps = self._get_dependencies()
|
720
720
|
assert isinstance(
|
721
721
|
dataset._session, Session
|
722
722
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -767,17 +767,15 @@ class GridSearchCV(BaseTransformer):
|
|
767
767
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
768
768
|
|
769
769
|
if isinstance(dataset, DataFrame):
|
770
|
-
self.
|
771
|
-
|
772
|
-
inference_method="score",
|
773
|
-
)
|
770
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
771
|
+
self._deps = self._get_dependencies()
|
774
772
|
selected_cols = self._get_active_columns()
|
775
773
|
if len(selected_cols) > 0:
|
776
774
|
dataset = dataset.select(selected_cols)
|
777
775
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
778
776
|
transform_kwargs = dict(
|
779
777
|
session=dataset._session,
|
780
|
-
dependencies=
|
778
|
+
dependencies=self._deps,
|
781
779
|
score_sproc_imports=["sklearn"],
|
782
780
|
)
|
783
781
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -347,8 +347,22 @@ class RandomizedSearchCV(BaseTransformer):
|
|
347
347
|
self._generate_model_signatures(dataset)
|
348
348
|
return self
|
349
349
|
|
350
|
-
def _batch_inference_validate_snowpark(
|
351
|
-
|
350
|
+
def _batch_inference_validate_snowpark(
|
351
|
+
self,
|
352
|
+
dataset: DataFrame,
|
353
|
+
inference_method: str,
|
354
|
+
) -> None:
|
355
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
356
|
+
|
357
|
+
Args:
|
358
|
+
dataset: snowpark dataframe
|
359
|
+
inference_method: the inference method such as predict, score...
|
360
|
+
|
361
|
+
Raises:
|
362
|
+
SnowflakeMLException: If the estimator is not fitted, raise error
|
363
|
+
SnowflakeMLException: If the session is None, raise error
|
364
|
+
|
365
|
+
"""
|
352
366
|
if not self._is_fitted:
|
353
367
|
raise exceptions.SnowflakeMLException(
|
354
368
|
error_code=error_codes.METHOD_NOT_ALLOWED,
|
@@ -363,10 +377,6 @@ class RandomizedSearchCV(BaseTransformer):
|
|
363
377
|
error_code=error_codes.NOT_FOUND,
|
364
378
|
original_exception=ValueError("Session must not specified for snowpark dataset."),
|
365
379
|
)
|
366
|
-
# Validate that key package version in user workspace are supported in snowflake conda channel
|
367
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
368
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT
|
369
|
-
)
|
370
380
|
|
371
381
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
372
382
|
@telemetry.send_api_usage_telemetry(
|
@@ -414,10 +424,9 @@ class RandomizedSearchCV(BaseTransformer):
|
|
414
424
|
)
|
415
425
|
|
416
426
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
417
|
-
self.
|
418
|
-
|
419
|
-
|
420
|
-
)
|
427
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
428
|
+
self._deps = self._get_dependencies()
|
429
|
+
|
421
430
|
assert isinstance(
|
422
431
|
dataset._session, Session
|
423
432
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -473,7 +482,9 @@ class RandomizedSearchCV(BaseTransformer):
|
|
473
482
|
inference_method = "transform"
|
474
483
|
|
475
484
|
if isinstance(dataset, DataFrame):
|
476
|
-
self.
|
485
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
486
|
+
self._deps = self._get_dependencies()
|
487
|
+
|
477
488
|
assert isinstance(
|
478
489
|
dataset._session, Session
|
479
490
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -531,7 +542,9 @@ class RandomizedSearchCV(BaseTransformer):
|
|
531
542
|
inference_method = "predict_proba"
|
532
543
|
|
533
544
|
if isinstance(dataset, DataFrame):
|
534
|
-
self.
|
545
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
546
|
+
self._deps = self._get_dependencies()
|
547
|
+
|
535
548
|
assert isinstance(
|
536
549
|
dataset._session, Session
|
537
550
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -591,7 +604,9 @@ class RandomizedSearchCV(BaseTransformer):
|
|
591
604
|
inference_method = "predict_log_proba"
|
592
605
|
|
593
606
|
if isinstance(dataset, DataFrame):
|
594
|
-
self.
|
607
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
608
|
+
self._deps = self._get_dependencies()
|
609
|
+
|
595
610
|
assert isinstance(
|
596
611
|
dataset._session, Session
|
597
612
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -650,7 +665,9 @@ class RandomizedSearchCV(BaseTransformer):
|
|
650
665
|
inference_method = "decision_function"
|
651
666
|
|
652
667
|
if isinstance(dataset, DataFrame):
|
653
|
-
self.
|
668
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
669
|
+
self._deps = self._get_dependencies()
|
670
|
+
|
654
671
|
assert isinstance(
|
655
672
|
dataset._session, Session
|
656
673
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -711,7 +728,9 @@ class RandomizedSearchCV(BaseTransformer):
|
|
711
728
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
712
729
|
|
713
730
|
if isinstance(dataset, DataFrame):
|
714
|
-
self.
|
731
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
732
|
+
self._deps = self._get_dependencies()
|
733
|
+
|
715
734
|
assert isinstance(
|
716
735
|
dataset._session, Session
|
717
736
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -761,10 +780,9 @@ class RandomizedSearchCV(BaseTransformer):
|
|
761
780
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
762
781
|
|
763
782
|
if isinstance(dataset, DataFrame):
|
764
|
-
self.
|
765
|
-
|
766
|
-
|
767
|
-
)
|
783
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
784
|
+
self._deps = self._get_dependencies()
|
785
|
+
|
768
786
|
selected_cols = self._get_active_columns()
|
769
787
|
if len(selected_cols) > 0:
|
770
788
|
dataset = dataset.select(selected_cols)
|
@@ -772,7 +790,7 @@ class RandomizedSearchCV(BaseTransformer):
|
|
772
790
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
773
791
|
transform_kwargs = dict(
|
774
792
|
session=dataset._session,
|
775
|
-
dependencies=
|
793
|
+
dependencies=self._deps,
|
776
794
|
score_sproc_imports=["sklearn"],
|
777
795
|
)
|
778
796
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.multiclass".replace("skl
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class OneVsOneClassifier(BaseTransformer):
|
70
64
|
r"""One-vs-one multiclass strategy
|
71
65
|
For more details on this class, see [sklearn.multiclass.OneVsOneClassifier]
|
@@ -271,20 +265,17 @@ class OneVsOneClassifier(BaseTransformer):
|
|
271
265
|
self,
|
272
266
|
dataset: DataFrame,
|
273
267
|
inference_method: str,
|
274
|
-
) ->
|
275
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
276
|
-
return the available package that exists in the snowflake anaconda channel
|
268
|
+
) -> None:
|
269
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
277
270
|
|
278
271
|
Args:
|
279
272
|
dataset: snowpark dataframe
|
280
273
|
inference_method: the inference method such as predict, score...
|
281
|
-
|
274
|
+
|
282
275
|
Raises:
|
283
276
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
284
277
|
SnowflakeMLException: If the session is None, raise error
|
285
278
|
|
286
|
-
Returns:
|
287
|
-
A list of available package that exists in the snowflake anaconda channel
|
288
279
|
"""
|
289
280
|
if not self._is_fitted:
|
290
281
|
raise exceptions.SnowflakeMLException(
|
@@ -302,9 +293,7 @@ class OneVsOneClassifier(BaseTransformer):
|
|
302
293
|
"Session must not specified for snowpark dataset."
|
303
294
|
),
|
304
295
|
)
|
305
|
-
|
306
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
307
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
296
|
+
|
308
297
|
|
309
298
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
310
299
|
@telemetry.send_api_usage_telemetry(
|
@@ -352,7 +341,8 @@ class OneVsOneClassifier(BaseTransformer):
|
|
352
341
|
|
353
342
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
354
343
|
|
355
|
-
self.
|
344
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
345
|
+
self._deps = self._get_dependencies()
|
356
346
|
assert isinstance(
|
357
347
|
dataset._session, Session
|
358
348
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -435,10 +425,8 @@ class OneVsOneClassifier(BaseTransformer):
|
|
435
425
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
436
426
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
437
427
|
|
438
|
-
self.
|
439
|
-
|
440
|
-
inference_method=inference_method,
|
441
|
-
)
|
428
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
429
|
+
self._deps = self._get_dependencies()
|
442
430
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
443
431
|
|
444
432
|
transform_kwargs = dict(
|
@@ -505,16 +493,40 @@ class OneVsOneClassifier(BaseTransformer):
|
|
505
493
|
self._is_fitted = True
|
506
494
|
return output_result
|
507
495
|
|
496
|
+
|
497
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
498
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
499
|
+
""" Method not supported for this class.
|
508
500
|
|
509
|
-
|
510
|
-
|
511
|
-
|
501
|
+
|
502
|
+
Raises:
|
503
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
504
|
+
|
505
|
+
Args:
|
506
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
507
|
+
Snowpark or Pandas DataFrame.
|
508
|
+
output_cols_prefix: Prefix for the response columns
|
512
509
|
Returns:
|
513
510
|
Transformed dataset.
|
514
511
|
"""
|
515
|
-
self.
|
516
|
-
|
517
|
-
|
512
|
+
self._infer_input_output_cols(dataset)
|
513
|
+
super()._check_dataset_type(dataset)
|
514
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
515
|
+
estimator=self._sklearn_object,
|
516
|
+
dataset=dataset,
|
517
|
+
input_cols=self.input_cols,
|
518
|
+
label_cols=self.label_cols,
|
519
|
+
sample_weight_col=self.sample_weight_col,
|
520
|
+
autogenerated=self._autogenerated,
|
521
|
+
subproject=_SUBPROJECT,
|
522
|
+
)
|
523
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
524
|
+
drop_input_cols=self._drop_input_cols,
|
525
|
+
expected_output_cols_list=self.output_cols,
|
526
|
+
)
|
527
|
+
self._sklearn_object = fitted_estimator
|
528
|
+
self._is_fitted = True
|
529
|
+
return output_result
|
518
530
|
|
519
531
|
|
520
532
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -605,10 +617,8 @@ class OneVsOneClassifier(BaseTransformer):
|
|
605
617
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
606
618
|
|
607
619
|
if isinstance(dataset, DataFrame):
|
608
|
-
self.
|
609
|
-
|
610
|
-
inference_method=inference_method,
|
611
|
-
)
|
620
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
621
|
+
self._deps = self._get_dependencies()
|
612
622
|
assert isinstance(
|
613
623
|
dataset._session, Session
|
614
624
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -673,10 +683,8 @@ class OneVsOneClassifier(BaseTransformer):
|
|
673
683
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
674
684
|
|
675
685
|
if isinstance(dataset, DataFrame):
|
676
|
-
self.
|
677
|
-
|
678
|
-
inference_method=inference_method,
|
679
|
-
)
|
686
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
687
|
+
self._deps = self._get_dependencies()
|
680
688
|
assert isinstance(
|
681
689
|
dataset._session, Session
|
682
690
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -740,10 +748,8 @@ class OneVsOneClassifier(BaseTransformer):
|
|
740
748
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
741
749
|
|
742
750
|
if isinstance(dataset, DataFrame):
|
743
|
-
self.
|
744
|
-
|
745
|
-
inference_method=inference_method,
|
746
|
-
)
|
751
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
752
|
+
self._deps = self._get_dependencies()
|
747
753
|
assert isinstance(
|
748
754
|
dataset._session, Session
|
749
755
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -809,10 +815,8 @@ class OneVsOneClassifier(BaseTransformer):
|
|
809
815
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
810
816
|
|
811
817
|
if isinstance(dataset, DataFrame):
|
812
|
-
self.
|
813
|
-
|
814
|
-
inference_method=inference_method,
|
815
|
-
)
|
818
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
819
|
+
self._deps = self._get_dependencies()
|
816
820
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
817
821
|
transform_kwargs = dict(
|
818
822
|
session=dataset._session,
|
@@ -876,17 +880,15 @@ class OneVsOneClassifier(BaseTransformer):
|
|
876
880
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
877
881
|
|
878
882
|
if isinstance(dataset, DataFrame):
|
879
|
-
self.
|
880
|
-
|
881
|
-
inference_method="score",
|
882
|
-
)
|
883
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
884
|
+
self._deps = self._get_dependencies()
|
883
885
|
selected_cols = self._get_active_columns()
|
884
886
|
if len(selected_cols) > 0:
|
885
887
|
dataset = dataset.select(selected_cols)
|
886
888
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
887
889
|
transform_kwargs = dict(
|
888
890
|
session=dataset._session,
|
889
|
-
dependencies=
|
891
|
+
dependencies=self._deps,
|
890
892
|
score_sproc_imports=['sklearn'],
|
891
893
|
)
|
892
894
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -951,11 +953,8 @@ class OneVsOneClassifier(BaseTransformer):
|
|
951
953
|
|
952
954
|
if isinstance(dataset, DataFrame):
|
953
955
|
|
954
|
-
self.
|
955
|
-
|
956
|
-
inference_method=inference_method,
|
957
|
-
|
958
|
-
)
|
956
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
957
|
+
self._deps = self._get_dependencies()
|
959
958
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
960
959
|
transform_kwargs = dict(
|
961
960
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.multiclass".replace("skl
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class OneVsRestClassifier(BaseTransformer):
|
70
64
|
r"""One-vs-the-rest (OvR) multiclass strategy
|
71
65
|
For more details on this class, see [sklearn.multiclass.OneVsRestClassifier]
|
@@ -280,20 +274,17 @@ class OneVsRestClassifier(BaseTransformer):
|
|
280
274
|
self,
|
281
275
|
dataset: DataFrame,
|
282
276
|
inference_method: str,
|
283
|
-
) ->
|
284
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
285
|
-
return the available package that exists in the snowflake anaconda channel
|
277
|
+
) -> None:
|
278
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
286
279
|
|
287
280
|
Args:
|
288
281
|
dataset: snowpark dataframe
|
289
282
|
inference_method: the inference method such as predict, score...
|
290
|
-
|
283
|
+
|
291
284
|
Raises:
|
292
285
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
293
286
|
SnowflakeMLException: If the session is None, raise error
|
294
287
|
|
295
|
-
Returns:
|
296
|
-
A list of available package that exists in the snowflake anaconda channel
|
297
288
|
"""
|
298
289
|
if not self._is_fitted:
|
299
290
|
raise exceptions.SnowflakeMLException(
|
@@ -311,9 +302,7 @@ class OneVsRestClassifier(BaseTransformer):
|
|
311
302
|
"Session must not specified for snowpark dataset."
|
312
303
|
),
|
313
304
|
)
|
314
|
-
|
315
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
316
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
305
|
+
|
317
306
|
|
318
307
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
319
308
|
@telemetry.send_api_usage_telemetry(
|
@@ -361,7 +350,8 @@ class OneVsRestClassifier(BaseTransformer):
|
|
361
350
|
|
362
351
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
363
352
|
|
364
|
-
self.
|
353
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
354
|
+
self._deps = self._get_dependencies()
|
365
355
|
assert isinstance(
|
366
356
|
dataset._session, Session
|
367
357
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -444,10 +434,8 @@ class OneVsRestClassifier(BaseTransformer):
|
|
444
434
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
445
435
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
446
436
|
|
447
|
-
self.
|
448
|
-
|
449
|
-
inference_method=inference_method,
|
450
|
-
)
|
437
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
438
|
+
self._deps = self._get_dependencies()
|
451
439
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
452
440
|
|
453
441
|
transform_kwargs = dict(
|
@@ -514,16 +502,40 @@ class OneVsRestClassifier(BaseTransformer):
|
|
514
502
|
self._is_fitted = True
|
515
503
|
return output_result
|
516
504
|
|
505
|
+
|
506
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
507
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
508
|
+
""" Method not supported for this class.
|
517
509
|
|
518
|
-
|
519
|
-
|
520
|
-
|
510
|
+
|
511
|
+
Raises:
|
512
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
513
|
+
|
514
|
+
Args:
|
515
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
516
|
+
Snowpark or Pandas DataFrame.
|
517
|
+
output_cols_prefix: Prefix for the response columns
|
521
518
|
Returns:
|
522
519
|
Transformed dataset.
|
523
520
|
"""
|
524
|
-
self.
|
525
|
-
|
526
|
-
|
521
|
+
self._infer_input_output_cols(dataset)
|
522
|
+
super()._check_dataset_type(dataset)
|
523
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
524
|
+
estimator=self._sklearn_object,
|
525
|
+
dataset=dataset,
|
526
|
+
input_cols=self.input_cols,
|
527
|
+
label_cols=self.label_cols,
|
528
|
+
sample_weight_col=self.sample_weight_col,
|
529
|
+
autogenerated=self._autogenerated,
|
530
|
+
subproject=_SUBPROJECT,
|
531
|
+
)
|
532
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
533
|
+
drop_input_cols=self._drop_input_cols,
|
534
|
+
expected_output_cols_list=self.output_cols,
|
535
|
+
)
|
536
|
+
self._sklearn_object = fitted_estimator
|
537
|
+
self._is_fitted = True
|
538
|
+
return output_result
|
527
539
|
|
528
540
|
|
529
541
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -616,10 +628,8 @@ class OneVsRestClassifier(BaseTransformer):
|
|
616
628
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
617
629
|
|
618
630
|
if isinstance(dataset, DataFrame):
|
619
|
-
self.
|
620
|
-
|
621
|
-
inference_method=inference_method,
|
622
|
-
)
|
631
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
632
|
+
self._deps = self._get_dependencies()
|
623
633
|
assert isinstance(
|
624
634
|
dataset._session, Session
|
625
635
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -686,10 +696,8 @@ class OneVsRestClassifier(BaseTransformer):
|
|
686
696
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
687
697
|
|
688
698
|
if isinstance(dataset, DataFrame):
|
689
|
-
self.
|
690
|
-
|
691
|
-
inference_method=inference_method,
|
692
|
-
)
|
699
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
700
|
+
self._deps = self._get_dependencies()
|
693
701
|
assert isinstance(
|
694
702
|
dataset._session, Session
|
695
703
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -753,10 +761,8 @@ class OneVsRestClassifier(BaseTransformer):
|
|
753
761
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
754
762
|
|
755
763
|
if isinstance(dataset, DataFrame):
|
756
|
-
self.
|
757
|
-
|
758
|
-
inference_method=inference_method,
|
759
|
-
)
|
764
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
765
|
+
self._deps = self._get_dependencies()
|
760
766
|
assert isinstance(
|
761
767
|
dataset._session, Session
|
762
768
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -822,10 +828,8 @@ class OneVsRestClassifier(BaseTransformer):
|
|
822
828
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
823
829
|
|
824
830
|
if isinstance(dataset, DataFrame):
|
825
|
-
self.
|
826
|
-
|
827
|
-
inference_method=inference_method,
|
828
|
-
)
|
831
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
832
|
+
self._deps = self._get_dependencies()
|
829
833
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
830
834
|
transform_kwargs = dict(
|
831
835
|
session=dataset._session,
|
@@ -889,17 +893,15 @@ class OneVsRestClassifier(BaseTransformer):
|
|
889
893
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
890
894
|
|
891
895
|
if isinstance(dataset, DataFrame):
|
892
|
-
self.
|
893
|
-
|
894
|
-
inference_method="score",
|
895
|
-
)
|
896
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
897
|
+
self._deps = self._get_dependencies()
|
896
898
|
selected_cols = self._get_active_columns()
|
897
899
|
if len(selected_cols) > 0:
|
898
900
|
dataset = dataset.select(selected_cols)
|
899
901
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
900
902
|
transform_kwargs = dict(
|
901
903
|
session=dataset._session,
|
902
|
-
dependencies=
|
904
|
+
dependencies=self._deps,
|
903
905
|
score_sproc_imports=['sklearn'],
|
904
906
|
)
|
905
907
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -964,11 +966,8 @@ class OneVsRestClassifier(BaseTransformer):
|
|
964
966
|
|
965
967
|
if isinstance(dataset, DataFrame):
|
966
968
|
|
967
|
-
self.
|
968
|
-
|
969
|
-
inference_method=inference_method,
|
970
|
-
|
971
|
-
)
|
969
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
970
|
+
self._deps = self._get_dependencies()
|
972
971
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
973
972
|
transform_kwargs = dict(
|
974
973
|
session = dataset._session,
|