snowflake-ml-python 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff shows the changes between two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- snowflake/ml/_internal/file_utils.py +3 -3
- snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
- snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
- snowflake/ml/_internal/telemetry.py +11 -2
- snowflake/ml/_internal/utils/formatting.py +1 -1
- snowflake/ml/feature_store/feature_store.py +15 -106
- snowflake/ml/fileset/sfcfs.py +4 -3
- snowflake/ml/fileset/stage_fs.py +18 -0
- snowflake/ml/model/_api.py +9 -9
- snowflake/ml/model/_client/model/model_version_impl.py +20 -15
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +3 -9
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -5
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +7 -6
- snowflake/ml/model/_model_composer/model_composer.py +10 -8
- snowflake/ml/model/_model_composer/model_method/function_generator.py +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +2 -2
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +1 -1
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +7 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +2 -2
- snowflake/ml/model/_packager/model_handlers/llm.py +1 -1
- snowflake/ml/model/_packager/model_handlers/mlflow.py +1 -1
- snowflake/ml/model/_packager/model_handlers/pytorch.py +13 -10
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +214 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +6 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +15 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +8 -8
- snowflake/ml/model/_packager/model_handlers/torchscript.py +7 -7
- snowflake/ml/model/_packager/model_handlers/xgboost.py +8 -8
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_packager.py +8 -6
- snowflake/ml/model/custom_model.py +3 -1
- snowflake/ml/model/type_hints.py +13 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +61 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -43
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +4 -4
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +21 -17
- snowflake/ml/modeling/_internal/model_specifications.py +3 -1
- snowflake/ml/modeling/_internal/model_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +547 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +67 -114
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -9
- snowflake/ml/modeling/_internal/transformer_protocols.py +2 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +33 -61
- snowflake/ml/modeling/cluster/affinity_propagation.py +33 -61
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +33 -61
- snowflake/ml/modeling/cluster/birch.py +33 -61
- snowflake/ml/modeling/cluster/bisecting_k_means.py +33 -61
- snowflake/ml/modeling/cluster/dbscan.py +33 -61
- snowflake/ml/modeling/cluster/feature_agglomeration.py +33 -61
- snowflake/ml/modeling/cluster/k_means.py +33 -61
- snowflake/ml/modeling/cluster/mean_shift.py +33 -61
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +33 -61
- snowflake/ml/modeling/cluster/optics.py +33 -61
- snowflake/ml/modeling/cluster/spectral_biclustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_clustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_coclustering.py +33 -61
- snowflake/ml/modeling/compose/column_transformer.py +33 -61
- snowflake/ml/modeling/compose/transformed_target_regressor.py +33 -61
- snowflake/ml/modeling/covariance/elliptic_envelope.py +33 -61
- snowflake/ml/modeling/covariance/empirical_covariance.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +33 -61
- snowflake/ml/modeling/covariance/ledoit_wolf.py +33 -61
- snowflake/ml/modeling/covariance/min_cov_det.py +33 -61
- snowflake/ml/modeling/covariance/oas.py +33 -61
- snowflake/ml/modeling/covariance/shrunk_covariance.py +33 -61
- snowflake/ml/modeling/decomposition/dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/factor_analysis.py +33 -61
- snowflake/ml/modeling/decomposition/fast_ica.py +33 -61
- snowflake/ml/modeling/decomposition/incremental_pca.py +33 -61
- snowflake/ml/modeling/decomposition/kernel_pca.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/pca.py +33 -61
- snowflake/ml/modeling/decomposition/sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/truncated_svd.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/isolation_forest.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/stacking_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/voting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/voting_regressor.py +33 -61
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fdr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fpr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fwe.py +33 -61
- snowflake/ml/modeling/feature_selection/select_k_best.py +33 -61
- snowflake/ml/modeling/feature_selection/select_percentile.py +33 -61
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +33 -61
- snowflake/ml/modeling/feature_selection/variance_threshold.py +33 -61
- snowflake/ml/modeling/framework/base.py +55 -5
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +33 -61
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +33 -61
- snowflake/ml/modeling/impute/iterative_imputer.py +33 -61
- snowflake/ml/modeling/impute/knn_imputer.py +33 -61
- snowflake/ml/modeling/impute/missing_indicator.py +33 -61
- snowflake/ml/modeling/impute/simple_imputer.py +4 -15
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/nystroem.py +33 -61
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +33 -61
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +33 -61
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +36 -63
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +36 -63
- snowflake/ml/modeling/linear_model/ard_regression.py +33 -61
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/gamma_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/huber_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/lars.py +33 -61
- snowflake/ml/modeling/linear_model/lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +33 -61
- snowflake/ml/modeling/linear_model/linear_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/perceptron.py +33 -61
- snowflake/ml/modeling/linear_model/poisson_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ransac_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ridge.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_cv.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +33 -61
- snowflake/ml/modeling/manifold/isomap.py +33 -61
- snowflake/ml/modeling/manifold/mds.py +33 -61
- snowflake/ml/modeling/manifold/spectral_embedding.py +33 -61
- snowflake/ml/modeling/manifold/tsne.py +33 -61
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +33 -61
- snowflake/ml/modeling/mixture/gaussian_mixture.py +33 -61
- snowflake/ml/modeling/model_selection/grid_search_cv.py +39 -57
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +26 -57
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/output_code_classifier.py +33 -61
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/complement_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neighbors/kernel_density.py +33 -61
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_centroid.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +33 -61
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_classifier.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_regressor.py +33 -61
- snowflake/ml/modeling/preprocessing/polynomial_features.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_propagation.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_spreading.py +33 -61
- snowflake/ml/modeling/svm/linear_svc.py +33 -61
- snowflake/ml/modeling/svm/linear_svr.py +33 -61
- snowflake/ml/modeling/svm/nu_svc.py +33 -61
- snowflake/ml/modeling/svm/nu_svr.py +33 -61
- snowflake/ml/modeling/svm/svc.py +33 -61
- snowflake/ml/modeling/svm/svr.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_regressor.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +33 -61
- snowflake/ml/registry/_manager/model_manager.py +6 -2
- snowflake/ml/registry/model_registry.py +100 -27
- snowflake/ml/registry/registry.py +6 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/METADATA +43 -7
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/RECORD +211 -206
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/top_level.txt +0 -0
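Among the additions visible in the file list, 1.4.0 ships a new model handler for sentence-transformers models (`snowflake/ml/model/_packager/model_handlers/sentence_transformers.py`, +214 lines). A minimal sketch of what logging such a model through the registry might look like, assuming the new handler accepts a `SentenceTransformer` instance directly; the connection parameters and names below are placeholders, not taken from this diff:

```python
from sentence_transformers import SentenceTransformer

from snowflake.ml.registry import Registry
from snowflake.snowpark import Session

# Placeholder connection config; substitute real credentials.
session = Session.builder.configs(
    {"account": "<account>", "user": "<user>", "password": "<password>"}
).create()

model = SentenceTransformer("all-MiniLM-L6-v2")  # any local sentence-transformers model
registry = Registry(session=session)
model_version = registry.log_model(
    model,
    model_name="MINILM_EMBEDDER",  # hypothetical name
    version_name="V1",
)
```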
snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py

```diff
@@ -386,18 +386,24 @@ class BayesianGaussianMixture(BaseTransformer):
         self._get_model_signatures(dataset)
         return self
 
-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
+
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
```
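The docstring rewrite above captures the behavioral change that repeats across every generated estimator in this release: `_batch_inference_validate_snowpark` previously only validated, and now also returns the list of packages that resolved against the Snowflake Anaconda channel, which call sites cache on `self._deps`. A standalone sketch of the pattern; the class and the resolver helper here are illustrative stand-ins, not the library's API:

```python
from typing import List


class EstimatorSketch:
    """Illustrative stand-in for a generated estimator class."""

    def __init__(self, declared_deps: List[str]) -> None:
        self._is_fitted = True
        self._declared_deps = declared_deps
        self._deps: List[str] = []

    def _resolve_in_anaconda_channel(self, pkgs: List[str]) -> List[str]:
        # Stand-in for pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel.
        return list(pkgs)

    def _batch_inference_validate_snowpark(self, inference_method: str) -> List[str]:
        if not self._is_fitted:
            raise RuntimeError(f"Call fit() before {inference_method}().")
        # 1.3.0 stopped after validation; 1.4.0 also returns the resolved packages.
        return self._resolve_in_anaconda_channel(self._declared_deps)


est = EstimatorSketch(["scikit-learn==1.3.*", "numpy"])
est._deps = est._batch_inference_validate_snowpark("predict")  # cached for transform_kwargs
print(est._deps)
```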
```diff
@@ -471,7 +477,7 @@ class BayesianGaussianMixture(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_type_inferred,
             )
 
@@ -531,16 +537,16 @@ class BayesianGaussianMixture(BaseTransformer):
             # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
             # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
             # each row containing a list of values.
-            expected_dtype = "
+            expected_dtype = "array"
 
         # If we were unable to assign a type to this transform in the factory, infer the type here.
         if expected_dtype == "":
-            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                expected_dtype = "
-            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
+                expected_dtype = "array"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                expected_dtype = "
+                expected_dtype = "array"
             else:
                 output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                 # We can only infer the output types from the input types if the following two statemetns are true:
```
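The `expected_dtype` hunk above (the old right-hand values are truncated by the diff rendering) encodes a rule that survives in the context lines: when a clustering or decomposition transformer emits more values per row than there are declared output columns, the result is packed into a single column typed as an array. A condensed, simplified sketch of that rule, not the exact generated code:

```python
from typing import List


def infer_expected_dtype(sklearn_obj: object, output_cols: List[str]) -> str:
    """Simplified restatement of the inference rule in the hunk above."""
    n_clusters = getattr(sklearn_obj, "n_clusters", None)
    if n_clusters is not None and n_clusters != len(output_cols):
        return "array"  # one output column, each row holding a list of values
    n_components = getattr(sklearn_obj, "n_components", None)
    if n_components is not None and n_components != len(output_cols):
        return "array"
    return ""  # fall through to signature-based type inference


class FakePCA:
    n_components = 3


print(infer_expected_dtype(FakePCA(), ["OUT"]))  # -> 'array'
```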
```diff
@@ -558,7 +564,7 @@ class BayesianGaussianMixture(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_dtype,
             )
 
@@ -611,7 +617,7 @@ class BayesianGaussianMixture(BaseTransformer):
             subproject=_SUBPROJECT,
         )
         output_result, fitted_estimator = model_trainer.train_fit_predict(
-
+            drop_input_cols=self._drop_input_cols,
             expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
         )
         self._sklearn_object = fitted_estimator
@@ -629,44 +635,6 @@ class BayesianGaussianMixture(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object.embedding_
 
-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-            else:
-                output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
```
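The helper removed above built one output column per class label for predict_proba()-style methods. Its naming scheme, restated as a runnable snippet for reference; the 1.4.0 replacement presumably lives in shared code such as the expanded `estimator_utils.py` in the file list, though this diff does not show the new location, and identifier quoting is omitted here:

```python
import numpy as np


def output_column_names(prefix: str, classes) -> list:
    """Mirrors the naming scheme of the removed _get_output_column_names."""
    if classes is None:
        return [prefix]  # not a classifier: single column named after the prefix
    if isinstance(classes, np.ndarray):
        return [f"{prefix}{c}" for c in classes.tolist()]
    # Multioutput estimators: classes_ is a list of ndarrays.
    cols = []
    for i, cl in enumerate(classes):
        if len(cl) == 2:  # binary output: the two class columns are complementary
            cols.append(f"{prefix}{i}_{cl[0]}")
        else:
            cols.extend(f"{prefix}{i}_{c}" for c in cl.tolist())
    return cols


print(output_column_names("PREDICT_PROBA_", np.array([0, 1, 2])))
# ['PREDICT_PROBA_0', 'PREDICT_PROBA_1', 'PREDICT_PROBA_2']
```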
```diff
@@ -708,7 +676,7 @@ class BayesianGaussianMixture(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -775,7 +743,7 @@ class BayesianGaussianMixture(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -836,7 +804,7 @@ class BayesianGaussianMixture(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -903,7 +871,7 @@ class BayesianGaussianMixture(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -959,13 +927,17 @@ class BayesianGaussianMixture(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self.
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
```
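After the score-path change above, the stored procedure's dependency list is built by prepending `snowflake-snowpark-python` to the freshly validated packages rather than to whatever attribute the truncated 1.3.0 line referenced. The resulting shape, with example values only:

```python
# Example values; the real list comes from _batch_inference_validate_snowpark("score").
validated_deps = ["scikit-learn==1.3.2", "numpy==1.24.3"]
score_dependencies = ["snowflake-snowpark-python"] + validated_deps
print(score_dependencies)
# ['snowflake-snowpark-python', 'scikit-learn==1.3.2', 'numpy==1.24.3']
```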
```diff
@@ -1039,9 +1011,9 @@ class BayesianGaussianMixture(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
-                expected_output_cols_type
-                n_neighbors =
+                drop_input_cols = self._drop_input_cols,
+                expected_output_cols_type="array",
+                n_neighbors = n_neighbors,
                 return_distance = return_distance
             )
         elif isinstance(dataset, pd.DataFrame):
```
snowflake/ml/modeling/mixture/gaussian_mixture.py

```diff
@@ -359,18 +359,24 @@ class GaussianMixture(BaseTransformer):
         self._get_model_signatures(dataset)
         return self
 
-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
+
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -444,7 +450,7 @@ class GaussianMixture(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_type_inferred,
             )
 
@@ -504,16 +510,16 @@ class GaussianMixture(BaseTransformer):
             # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
             # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
             # each row containing a list of values.
-            expected_dtype = "
+            expected_dtype = "array"
 
         # If we were unable to assign a type to this transform in the factory, infer the type here.
         if expected_dtype == "":
-            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                expected_dtype = "
-            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
+                expected_dtype = "array"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                expected_dtype = "
+                expected_dtype = "array"
             else:
                 output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                 # We can only infer the output types from the input types if the following two statemetns are true:
@@ -531,7 +537,7 @@ class GaussianMixture(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_dtype,
             )
 
@@ -584,7 +590,7 @@ class GaussianMixture(BaseTransformer):
             subproject=_SUBPROJECT,
         )
         output_result, fitted_estimator = model_trainer.train_fit_predict(
-
+            drop_input_cols=self._drop_input_cols,
             expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
         )
         self._sklearn_object = fitted_estimator
@@ -602,44 +608,6 @@ class GaussianMixture(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object.embedding_
 
-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-            else:
-                output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
@@ -681,7 +649,7 @@ class GaussianMixture(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -748,7 +716,7 @@ class GaussianMixture(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -809,7 +777,7 @@ class GaussianMixture(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -876,7 +844,7 @@ class GaussianMixture(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -932,13 +900,17 @@ class GaussianMixture(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self.
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -1012,9 +984,9 @@ class GaussianMixture(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
-                expected_output_cols_type
-                n_neighbors =
+                drop_input_cols = self._drop_input_cols,
+                expected_output_cols_type="array",
+                n_neighbors = n_neighbors,
                 return_distance = return_distance
             )
         elif isinstance(dataset, pd.DataFrame):
```
snowflake/ml/modeling/model_selection/grid_search_cv.py

```diff
@@ -216,6 +216,7 @@ class GridSearchCV(BaseTransformer):
     expensive and is not strictly required to select the parameters that
     yield the best generalization performance.
     """
+
     _ENABLE_DISTRIBUTED = True
 
     def __init__(  # type: ignore[no-untyped-def]
@@ -332,14 +333,21 @@ class GridSearchCV(BaseTransformer):
         self._get_model_signatures(dataset)
         return self
 
-    def
-
-
-
-
+    def _batch_inference_validate_snowpark(self, dataset: DataFrame, inference_method: str) -> List[str]:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
 
-
-
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
+        """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
                 error_code=error_codes.METHOD_NOT_ALLOWED,
@@ -355,7 +363,7 @@ class GridSearchCV(BaseTransformer):
                 original_exception=ValueError("Session must not specified for snowpark dataset."),
             )
         # Validate that key package version in user workspace are supported in snowflake conda channel
-        pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
+        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
             pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT
         )
 
```
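This hunk exposes the mechanism behind the docstring change: the method already computed the channel-valid package versions via `pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel`, and 1.4.0 simply returns that list instead of discarding it. A toy model of what such a resolver does; the availability table and fallback policy here are fabricated for illustration:

```python
from typing import Dict, List

# Fabricated availability table standing in for the Snowflake Anaconda channel.
CHANNEL: Dict[str, List[str]] = {
    "scikit-learn": ["1.2.2", "1.3.0"],
    "numpy": ["1.24.3"],
}


def valid_pkg_versions(pkg_versions: List[str]) -> List[str]:
    resolved = []
    for spec in pkg_versions:
        name, _, pinned = spec.partition("==")
        available = CHANNEL.get(name, [])
        if pinned and pinned in available:
            resolved.append(spec)  # exact pin exists in the channel
        elif available:
            resolved.append(f"{name}=={available[-1]}")  # fall back to newest (toy policy)
    return resolved


print(valid_pkg_versions(["scikit-learn==1.3.0", "numpy==1.26.0"]))
# ['scikit-learn==1.3.0', 'numpy==1.24.3']
```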
```diff
@@ -391,7 +399,7 @@ class GridSearchCV(BaseTransformer):
             expected_type_inferred = convert_sp_to_sf_type(
                 self.model_signatures["predict"].outputs[0].as_snowpark_type()
             )
-            self._batch_inference_validate_snowpark(
+            self._deps = self._batch_inference_validate_snowpark(
                 dataset=dataset,
                 inference_method=inference_method,
             )
@@ -402,8 +410,8 @@ class GridSearchCV(BaseTransformer):
 
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=self.
-
+                dependencies=self._deps,
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type=expected_type_inferred,
             )
 
@@ -452,15 +460,15 @@ class GridSearchCV(BaseTransformer):
         inference_method = "transform"
 
         if isinstance(dataset, DataFrame):
-            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=self.
-
+                dependencies=self._deps,
+                drop_input_cols=self._drop_input_cols,
             )
 
         elif isinstance(dataset, pd.DataFrame):
@@ -482,36 +490,6 @@ class GridSearchCV(BaseTransformer):
         )
         return output_df
 
-    def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
-        """Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-
-        Args:
-            output_cols_prefix (str): prefix according to the function
-
-        Returns:
-            List[str]: output cols with prefix
-        """
-        if getattr(self._sklearn_object, "classes_", None) is None:
-            return [output_cols_prefix]
-
-        assert self._sklearn_object is not None  # keep mypy happy
-        classes = self._sklearn_object.classes_
-        if isinstance(classes, np.ndarray):
-            return [f"{output_cols_prefix}{c}" for c in classes.tolist()]
-        elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], np.ndarray):
-            # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-            output_cols = []
-            for i, cl in enumerate(classes):
-                # For binary classification, there is only one output column for each class
-                # ndarray as the two classes are complementary.
-                if len(cl) == 2:
-                    output_cols.append(f"{output_cols_prefix}_{i}_{cl[0]}")
-                else:
-                    output_cols.extend([f"{output_cols_prefix}_{i}_{c}" for c in cl.tolist()])
-            return output_cols
-        return []
-
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
@@ -541,14 +519,14 @@ class GridSearchCV(BaseTransformer):
         inference_method = "predict_proba"
 
         if isinstance(dataset, DataFrame):
-            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=self.
-
+                dependencies=self._deps,
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -601,14 +579,14 @@ class GridSearchCV(BaseTransformer):
         inference_method = "predict_log_proba"
 
         if isinstance(dataset, DataFrame):
-            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
-
-
+                drop_input_cols=self._drop_input_cols,
+                dependencies=self._deps,
                 expected_output_cols_type="float",
             )
 
@@ -661,14 +639,14 @@ class GridSearchCV(BaseTransformer):
         inference_method = "decision_function"
 
         if isinstance(dataset, DataFrame):
-            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=self.
-
+                dependencies=self._deps,
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -722,14 +700,14 @@ class GridSearchCV(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=self.
-
+                dependencies=self._deps,
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -773,13 +751,17 @@ class GridSearchCV(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self.
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                 score_sproc_imports=["sklearn"],
             )
         elif isinstance(dataset, pd.DataFrame):
```