snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +72 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
- snowflake/ml/_internal/telemetry.py +1 -0
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/sql_identifier.py +14 -1
- snowflake/ml/dataset/__init__.py +11 -0
- snowflake/ml/dataset/dataset.py +455 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +199 -0
- snowflake/ml/feature_store/__init__.py +6 -0
- snowflake/ml/feature_store/access_manager.py +279 -0
- snowflake/ml/feature_store/feature_store.py +544 -358
- snowflake/ml/feature_store/feature_view.py +55 -16
- snowflake/ml/fileset/embedded_stage_fs.py +149 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +160 -0
- snowflake/ml/fileset/stage_fs.py +25 -10
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +65 -31
- snowflake/ml/model/_client/model/model_version_impl.py +159 -2
- snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
- snowflake/ml/model/_client/ops/model_ops.py +268 -83
- snowflake/ml/model/_client/sql/_base.py +34 -0
- snowflake/ml/model/_client/sql/model.py +42 -47
- snowflake/ml/model/_client/sql/model_version.py +164 -39
- snowflake/ml/model/_client/sql/stage.py +6 -32
- snowflake/ml/model/_client/sql/tag.py +32 -56
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +64 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +538 -36
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/_manager/model_manager.py +36 -7
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class Ridge(BaseTransformer):
|
70
64
|
r"""Linear least squares with l2 regularization
|
71
65
|
For more details on this class, see [sklearn.linear_model.Ridge]
|
@@ -358,20 +352,17 @@ class Ridge(BaseTransformer):
|
|
358
352
|
self,
|
359
353
|
dataset: DataFrame,
|
360
354
|
inference_method: str,
|
361
|
-
) ->
|
362
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
363
|
-
return the available package that exists in the snowflake anaconda channel
|
355
|
+
) -> None:
|
356
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
364
357
|
|
365
358
|
Args:
|
366
359
|
dataset: snowpark dataframe
|
367
360
|
inference_method: the inference method such as predict, score...
|
368
|
-
|
361
|
+
|
369
362
|
Raises:
|
370
363
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
371
364
|
SnowflakeMLException: If the session is None, raise error
|
372
365
|
|
373
|
-
Returns:
|
374
|
-
A list of available package that exists in the snowflake anaconda channel
|
375
366
|
"""
|
376
367
|
if not self._is_fitted:
|
377
368
|
raise exceptions.SnowflakeMLException(
|
@@ -389,9 +380,7 @@ class Ridge(BaseTransformer):
|
|
389
380
|
"Session must not specified for snowpark dataset."
|
390
381
|
),
|
391
382
|
)
|
392
|
-
|
393
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
394
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
383
|
+
|
395
384
|
|
396
385
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
397
386
|
@telemetry.send_api_usage_telemetry(
|
@@ -439,7 +428,8 @@ class Ridge(BaseTransformer):
|
|
439
428
|
|
440
429
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
441
430
|
|
442
|
-
self.
|
431
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
432
|
+
self._deps = self._get_dependencies()
|
443
433
|
assert isinstance(
|
444
434
|
dataset._session, Session
|
445
435
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -522,10 +512,8 @@ class Ridge(BaseTransformer):
|
|
522
512
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
523
513
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
524
514
|
|
525
|
-
self.
|
526
|
-
|
527
|
-
inference_method=inference_method,
|
528
|
-
)
|
515
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
516
|
+
self._deps = self._get_dependencies()
|
529
517
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
530
518
|
|
531
519
|
transform_kwargs = dict(
|
@@ -592,16 +580,40 @@ class Ridge(BaseTransformer):
|
|
592
580
|
self._is_fitted = True
|
593
581
|
return output_result
|
594
582
|
|
583
|
+
|
584
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
585
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
586
|
+
""" Method not supported for this class.
|
595
587
|
|
596
|
-
|
597
|
-
|
598
|
-
|
588
|
+
|
589
|
+
Raises:
|
590
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
591
|
+
|
592
|
+
Args:
|
593
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
594
|
+
Snowpark or Pandas DataFrame.
|
595
|
+
output_cols_prefix: Prefix for the response columns
|
599
596
|
Returns:
|
600
597
|
Transformed dataset.
|
601
598
|
"""
|
602
|
-
self.
|
603
|
-
|
604
|
-
|
599
|
+
self._infer_input_output_cols(dataset)
|
600
|
+
super()._check_dataset_type(dataset)
|
601
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
602
|
+
estimator=self._sklearn_object,
|
603
|
+
dataset=dataset,
|
604
|
+
input_cols=self.input_cols,
|
605
|
+
label_cols=self.label_cols,
|
606
|
+
sample_weight_col=self.sample_weight_col,
|
607
|
+
autogenerated=self._autogenerated,
|
608
|
+
subproject=_SUBPROJECT,
|
609
|
+
)
|
610
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
611
|
+
drop_input_cols=self._drop_input_cols,
|
612
|
+
expected_output_cols_list=self.output_cols,
|
613
|
+
)
|
614
|
+
self._sklearn_object = fitted_estimator
|
615
|
+
self._is_fitted = True
|
616
|
+
return output_result
|
605
617
|
|
606
618
|
|
607
619
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -692,10 +704,8 @@ class Ridge(BaseTransformer):
|
|
692
704
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
693
705
|
|
694
706
|
if isinstance(dataset, DataFrame):
|
695
|
-
self.
|
696
|
-
|
697
|
-
inference_method=inference_method,
|
698
|
-
)
|
707
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
708
|
+
self._deps = self._get_dependencies()
|
699
709
|
assert isinstance(
|
700
710
|
dataset._session, Session
|
701
711
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -760,10 +770,8 @@ class Ridge(BaseTransformer):
|
|
760
770
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
761
771
|
|
762
772
|
if isinstance(dataset, DataFrame):
|
763
|
-
self.
|
764
|
-
|
765
|
-
inference_method=inference_method,
|
766
|
-
)
|
773
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
774
|
+
self._deps = self._get_dependencies()
|
767
775
|
assert isinstance(
|
768
776
|
dataset._session, Session
|
769
777
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -825,10 +833,8 @@ class Ridge(BaseTransformer):
|
|
825
833
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
826
834
|
|
827
835
|
if isinstance(dataset, DataFrame):
|
828
|
-
self.
|
829
|
-
|
830
|
-
inference_method=inference_method,
|
831
|
-
)
|
836
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
837
|
+
self._deps = self._get_dependencies()
|
832
838
|
assert isinstance(
|
833
839
|
dataset._session, Session
|
834
840
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -894,10 +900,8 @@ class Ridge(BaseTransformer):
|
|
894
900
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
895
901
|
|
896
902
|
if isinstance(dataset, DataFrame):
|
897
|
-
self.
|
898
|
-
|
899
|
-
inference_method=inference_method,
|
900
|
-
)
|
903
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
904
|
+
self._deps = self._get_dependencies()
|
901
905
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
902
906
|
transform_kwargs = dict(
|
903
907
|
session=dataset._session,
|
@@ -961,17 +965,15 @@ class Ridge(BaseTransformer):
|
|
961
965
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
962
966
|
|
963
967
|
if isinstance(dataset, DataFrame):
|
964
|
-
self.
|
965
|
-
|
966
|
-
inference_method="score",
|
967
|
-
)
|
968
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
969
|
+
self._deps = self._get_dependencies()
|
968
970
|
selected_cols = self._get_active_columns()
|
969
971
|
if len(selected_cols) > 0:
|
970
972
|
dataset = dataset.select(selected_cols)
|
971
973
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
972
974
|
transform_kwargs = dict(
|
973
975
|
session=dataset._session,
|
974
|
-
dependencies=
|
976
|
+
dependencies=self._deps,
|
975
977
|
score_sproc_imports=['sklearn'],
|
976
978
|
)
|
977
979
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1036,11 +1038,8 @@ class Ridge(BaseTransformer):
|
|
1036
1038
|
|
1037
1039
|
if isinstance(dataset, DataFrame):
|
1038
1040
|
|
1039
|
-
self.
|
1040
|
-
|
1041
|
-
inference_method=inference_method,
|
1042
|
-
|
1043
|
-
)
|
1041
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1042
|
+
self._deps = self._get_dependencies()
|
1044
1043
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1045
1044
|
transform_kwargs = dict(
|
1046
1045
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class RidgeClassifier(BaseTransformer):
|
70
64
|
r"""Classifier using Ridge regression
|
71
65
|
For more details on this class, see [sklearn.linear_model.RidgeClassifier]
|
@@ -358,20 +352,17 @@ class RidgeClassifier(BaseTransformer):
|
|
358
352
|
self,
|
359
353
|
dataset: DataFrame,
|
360
354
|
inference_method: str,
|
361
|
-
) ->
|
362
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
363
|
-
return the available package that exists in the snowflake anaconda channel
|
355
|
+
) -> None:
|
356
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
364
357
|
|
365
358
|
Args:
|
366
359
|
dataset: snowpark dataframe
|
367
360
|
inference_method: the inference method such as predict, score...
|
368
|
-
|
361
|
+
|
369
362
|
Raises:
|
370
363
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
371
364
|
SnowflakeMLException: If the session is None, raise error
|
372
365
|
|
373
|
-
Returns:
|
374
|
-
A list of available package that exists in the snowflake anaconda channel
|
375
366
|
"""
|
376
367
|
if not self._is_fitted:
|
377
368
|
raise exceptions.SnowflakeMLException(
|
@@ -389,9 +380,7 @@ class RidgeClassifier(BaseTransformer):
|
|
389
380
|
"Session must not specified for snowpark dataset."
|
390
381
|
),
|
391
382
|
)
|
392
|
-
|
393
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
394
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
383
|
+
|
395
384
|
|
396
385
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
397
386
|
@telemetry.send_api_usage_telemetry(
|
@@ -439,7 +428,8 @@ class RidgeClassifier(BaseTransformer):
|
|
439
428
|
|
440
429
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
441
430
|
|
442
|
-
self.
|
431
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
432
|
+
self._deps = self._get_dependencies()
|
443
433
|
assert isinstance(
|
444
434
|
dataset._session, Session
|
445
435
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -522,10 +512,8 @@ class RidgeClassifier(BaseTransformer):
|
|
522
512
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
523
513
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
524
514
|
|
525
|
-
self.
|
526
|
-
|
527
|
-
inference_method=inference_method,
|
528
|
-
)
|
515
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
516
|
+
self._deps = self._get_dependencies()
|
529
517
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
530
518
|
|
531
519
|
transform_kwargs = dict(
|
@@ -592,16 +580,40 @@ class RidgeClassifier(BaseTransformer):
|
|
592
580
|
self._is_fitted = True
|
593
581
|
return output_result
|
594
582
|
|
583
|
+
|
584
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
585
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
586
|
+
""" Method not supported for this class.
|
595
587
|
|
596
|
-
|
597
|
-
|
598
|
-
|
588
|
+
|
589
|
+
Raises:
|
590
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
591
|
+
|
592
|
+
Args:
|
593
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
594
|
+
Snowpark or Pandas DataFrame.
|
595
|
+
output_cols_prefix: Prefix for the response columns
|
599
596
|
Returns:
|
600
597
|
Transformed dataset.
|
601
598
|
"""
|
602
|
-
self.
|
603
|
-
|
604
|
-
|
599
|
+
self._infer_input_output_cols(dataset)
|
600
|
+
super()._check_dataset_type(dataset)
|
601
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
602
|
+
estimator=self._sklearn_object,
|
603
|
+
dataset=dataset,
|
604
|
+
input_cols=self.input_cols,
|
605
|
+
label_cols=self.label_cols,
|
606
|
+
sample_weight_col=self.sample_weight_col,
|
607
|
+
autogenerated=self._autogenerated,
|
608
|
+
subproject=_SUBPROJECT,
|
609
|
+
)
|
610
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
611
|
+
drop_input_cols=self._drop_input_cols,
|
612
|
+
expected_output_cols_list=self.output_cols,
|
613
|
+
)
|
614
|
+
self._sklearn_object = fitted_estimator
|
615
|
+
self._is_fitted = True
|
616
|
+
return output_result
|
605
617
|
|
606
618
|
|
607
619
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -692,10 +704,8 @@ class RidgeClassifier(BaseTransformer):
|
|
692
704
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
693
705
|
|
694
706
|
if isinstance(dataset, DataFrame):
|
695
|
-
self.
|
696
|
-
|
697
|
-
inference_method=inference_method,
|
698
|
-
)
|
707
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
708
|
+
self._deps = self._get_dependencies()
|
699
709
|
assert isinstance(
|
700
710
|
dataset._session, Session
|
701
711
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -760,10 +770,8 @@ class RidgeClassifier(BaseTransformer):
|
|
760
770
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
761
771
|
|
762
772
|
if isinstance(dataset, DataFrame):
|
763
|
-
self.
|
764
|
-
|
765
|
-
inference_method=inference_method,
|
766
|
-
)
|
773
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
774
|
+
self._deps = self._get_dependencies()
|
767
775
|
assert isinstance(
|
768
776
|
dataset._session, Session
|
769
777
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -827,10 +835,8 @@ class RidgeClassifier(BaseTransformer):
|
|
827
835
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
828
836
|
|
829
837
|
if isinstance(dataset, DataFrame):
|
830
|
-
self.
|
831
|
-
|
832
|
-
inference_method=inference_method,
|
833
|
-
)
|
838
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
839
|
+
self._deps = self._get_dependencies()
|
834
840
|
assert isinstance(
|
835
841
|
dataset._session, Session
|
836
842
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -896,10 +902,8 @@ class RidgeClassifier(BaseTransformer):
|
|
896
902
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
897
903
|
|
898
904
|
if isinstance(dataset, DataFrame):
|
899
|
-
self.
|
900
|
-
|
901
|
-
inference_method=inference_method,
|
902
|
-
)
|
905
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
906
|
+
self._deps = self._get_dependencies()
|
903
907
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
904
908
|
transform_kwargs = dict(
|
905
909
|
session=dataset._session,
|
@@ -963,17 +967,15 @@ class RidgeClassifier(BaseTransformer):
|
|
963
967
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
964
968
|
|
965
969
|
if isinstance(dataset, DataFrame):
|
966
|
-
self.
|
967
|
-
|
968
|
-
inference_method="score",
|
969
|
-
)
|
970
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
971
|
+
self._deps = self._get_dependencies()
|
970
972
|
selected_cols = self._get_active_columns()
|
971
973
|
if len(selected_cols) > 0:
|
972
974
|
dataset = dataset.select(selected_cols)
|
973
975
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
974
976
|
transform_kwargs = dict(
|
975
977
|
session=dataset._session,
|
976
|
-
dependencies=
|
978
|
+
dependencies=self._deps,
|
977
979
|
score_sproc_imports=['sklearn'],
|
978
980
|
)
|
979
981
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1038,11 +1040,8 @@ class RidgeClassifier(BaseTransformer):
|
|
1038
1040
|
|
1039
1041
|
if isinstance(dataset, DataFrame):
|
1040
1042
|
|
1041
|
-
self.
|
1042
|
-
|
1043
|
-
inference_method=inference_method,
|
1044
|
-
|
1045
|
-
)
|
1043
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1044
|
+
self._deps = self._get_dependencies()
|
1046
1045
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1047
1046
|
transform_kwargs = dict(
|
1048
1047
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class RidgeClassifierCV(BaseTransformer):
|
70
64
|
r"""Ridge classifier with built-in cross-validation
|
71
65
|
For more details on this class, see [sklearn.linear_model.RidgeClassifierCV]
|
@@ -309,20 +303,17 @@ class RidgeClassifierCV(BaseTransformer):
|
|
309
303
|
self,
|
310
304
|
dataset: DataFrame,
|
311
305
|
inference_method: str,
|
312
|
-
) ->
|
313
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
314
|
-
return the available package that exists in the snowflake anaconda channel
|
306
|
+
) -> None:
|
307
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
315
308
|
|
316
309
|
Args:
|
317
310
|
dataset: snowpark dataframe
|
318
311
|
inference_method: the inference method such as predict, score...
|
319
|
-
|
312
|
+
|
320
313
|
Raises:
|
321
314
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
322
315
|
SnowflakeMLException: If the session is None, raise error
|
323
316
|
|
324
|
-
Returns:
|
325
|
-
A list of available package that exists in the snowflake anaconda channel
|
326
317
|
"""
|
327
318
|
if not self._is_fitted:
|
328
319
|
raise exceptions.SnowflakeMLException(
|
@@ -340,9 +331,7 @@ class RidgeClassifierCV(BaseTransformer):
|
|
340
331
|
"Session must not specified for snowpark dataset."
|
341
332
|
),
|
342
333
|
)
|
343
|
-
|
344
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
345
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
334
|
+
|
346
335
|
|
347
336
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
348
337
|
@telemetry.send_api_usage_telemetry(
|
@@ -390,7 +379,8 @@ class RidgeClassifierCV(BaseTransformer):
|
|
390
379
|
|
391
380
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
392
381
|
|
393
|
-
self.
|
382
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
383
|
+
self._deps = self._get_dependencies()
|
394
384
|
assert isinstance(
|
395
385
|
dataset._session, Session
|
396
386
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -473,10 +463,8 @@ class RidgeClassifierCV(BaseTransformer):
|
|
473
463
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
474
464
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
475
465
|
|
476
|
-
self.
|
477
|
-
|
478
|
-
inference_method=inference_method,
|
479
|
-
)
|
466
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
467
|
+
self._deps = self._get_dependencies()
|
480
468
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
481
469
|
|
482
470
|
transform_kwargs = dict(
|
@@ -543,16 +531,40 @@ class RidgeClassifierCV(BaseTransformer):
|
|
543
531
|
self._is_fitted = True
|
544
532
|
return output_result
|
545
533
|
|
534
|
+
|
535
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
536
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
537
|
+
""" Method not supported for this class.
|
546
538
|
|
547
|
-
|
548
|
-
|
549
|
-
|
539
|
+
|
540
|
+
Raises:
|
541
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
542
|
+
|
543
|
+
Args:
|
544
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
545
|
+
Snowpark or Pandas DataFrame.
|
546
|
+
output_cols_prefix: Prefix for the response columns
|
550
547
|
Returns:
|
551
548
|
Transformed dataset.
|
552
549
|
"""
|
553
|
-
self.
|
554
|
-
|
555
|
-
|
550
|
+
self._infer_input_output_cols(dataset)
|
551
|
+
super()._check_dataset_type(dataset)
|
552
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
553
|
+
estimator=self._sklearn_object,
|
554
|
+
dataset=dataset,
|
555
|
+
input_cols=self.input_cols,
|
556
|
+
label_cols=self.label_cols,
|
557
|
+
sample_weight_col=self.sample_weight_col,
|
558
|
+
autogenerated=self._autogenerated,
|
559
|
+
subproject=_SUBPROJECT,
|
560
|
+
)
|
561
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
562
|
+
drop_input_cols=self._drop_input_cols,
|
563
|
+
expected_output_cols_list=self.output_cols,
|
564
|
+
)
|
565
|
+
self._sklearn_object = fitted_estimator
|
566
|
+
self._is_fitted = True
|
567
|
+
return output_result
|
556
568
|
|
557
569
|
|
558
570
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -643,10 +655,8 @@ class RidgeClassifierCV(BaseTransformer):
|
|
643
655
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
644
656
|
|
645
657
|
if isinstance(dataset, DataFrame):
|
646
|
-
self.
|
647
|
-
|
648
|
-
inference_method=inference_method,
|
649
|
-
)
|
658
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
659
|
+
self._deps = self._get_dependencies()
|
650
660
|
assert isinstance(
|
651
661
|
dataset._session, Session
|
652
662
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -711,10 +721,8 @@ class RidgeClassifierCV(BaseTransformer):
|
|
711
721
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
712
722
|
|
713
723
|
if isinstance(dataset, DataFrame):
|
714
|
-
self.
|
715
|
-
|
716
|
-
inference_method=inference_method,
|
717
|
-
)
|
724
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
725
|
+
self._deps = self._get_dependencies()
|
718
726
|
assert isinstance(
|
719
727
|
dataset._session, Session
|
720
728
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -778,10 +786,8 @@ class RidgeClassifierCV(BaseTransformer):
|
|
778
786
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
779
787
|
|
780
788
|
if isinstance(dataset, DataFrame):
|
781
|
-
self.
|
782
|
-
|
783
|
-
inference_method=inference_method,
|
784
|
-
)
|
789
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
790
|
+
self._deps = self._get_dependencies()
|
785
791
|
assert isinstance(
|
786
792
|
dataset._session, Session
|
787
793
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -847,10 +853,8 @@ class RidgeClassifierCV(BaseTransformer):
|
|
847
853
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
848
854
|
|
849
855
|
if isinstance(dataset, DataFrame):
|
850
|
-
self.
|
851
|
-
|
852
|
-
inference_method=inference_method,
|
853
|
-
)
|
856
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
857
|
+
self._deps = self._get_dependencies()
|
854
858
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
855
859
|
transform_kwargs = dict(
|
856
860
|
session=dataset._session,
|
@@ -914,17 +918,15 @@ class RidgeClassifierCV(BaseTransformer):
|
|
914
918
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
915
919
|
|
916
920
|
if isinstance(dataset, DataFrame):
|
917
|
-
self.
|
918
|
-
|
919
|
-
inference_method="score",
|
920
|
-
)
|
921
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
922
|
+
self._deps = self._get_dependencies()
|
921
923
|
selected_cols = self._get_active_columns()
|
922
924
|
if len(selected_cols) > 0:
|
923
925
|
dataset = dataset.select(selected_cols)
|
924
926
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
925
927
|
transform_kwargs = dict(
|
926
928
|
session=dataset._session,
|
927
|
-
dependencies=
|
929
|
+
dependencies=self._deps,
|
928
930
|
score_sproc_imports=['sklearn'],
|
929
931
|
)
|
930
932
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -989,11 +991,8 @@ class RidgeClassifierCV(BaseTransformer):
|
|
989
991
|
|
990
992
|
if isinstance(dataset, DataFrame):
|
991
993
|
|
992
|
-
self.
|
993
|
-
|
994
|
-
inference_method=inference_method,
|
995
|
-
|
996
|
-
)
|
994
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
995
|
+
self._deps = self._get_dependencies()
|
997
996
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
998
997
|
transform_kwargs = dict(
|
999
998
|
session = dataset._session,
|