snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +72 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
- snowflake/ml/_internal/telemetry.py +1 -0
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/sql_identifier.py +14 -1
- snowflake/ml/dataset/__init__.py +11 -0
- snowflake/ml/dataset/dataset.py +455 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +199 -0
- snowflake/ml/feature_store/__init__.py +6 -0
- snowflake/ml/feature_store/access_manager.py +279 -0
- snowflake/ml/feature_store/feature_store.py +544 -358
- snowflake/ml/feature_store/feature_view.py +55 -16
- snowflake/ml/fileset/embedded_stage_fs.py +149 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +160 -0
- snowflake/ml/fileset/stage_fs.py +25 -10
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +65 -31
- snowflake/ml/model/_client/model/model_version_impl.py +159 -2
- snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
- snowflake/ml/model/_client/ops/model_ops.py +268 -83
- snowflake/ml/model/_client/sql/_base.py +34 -0
- snowflake/ml/model/_client/sql/model.py +42 -47
- snowflake/ml/model/_client/sql/model_version.py +164 -39
- snowflake/ml/model/_client/sql/stage.py +6 -32
- snowflake/ml/model/_client/sql/tag.py +32 -56
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +64 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +538 -36
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/_manager/model_manager.py +36 -7
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "lightgbm".replace("sklearn.", ""
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class LGBMRegressor(BaseTransformer):
|
70
64
|
r"""LightGBM regressor
|
71
65
|
For more details on this class, see [lightgbm.LGBMRegressor]
|
@@ -294,20 +288,17 @@ class LGBMRegressor(BaseTransformer):
|
|
294
288
|
self,
|
295
289
|
dataset: DataFrame,
|
296
290
|
inference_method: str,
|
297
|
-
) ->
|
298
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
299
|
-
return the available package that exists in the snowflake anaconda channel
|
291
|
+
) -> None:
|
292
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
300
293
|
|
301
294
|
Args:
|
302
295
|
dataset: snowpark dataframe
|
303
296
|
inference_method: the inference method such as predict, score...
|
304
|
-
|
297
|
+
|
305
298
|
Raises:
|
306
299
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
307
300
|
SnowflakeMLException: If the session is None, raise error
|
308
301
|
|
309
|
-
Returns:
|
310
|
-
A list of available package that exists in the snowflake anaconda channel
|
311
302
|
"""
|
312
303
|
if not self._is_fitted:
|
313
304
|
raise exceptions.SnowflakeMLException(
|
@@ -325,9 +316,7 @@ class LGBMRegressor(BaseTransformer):
|
|
325
316
|
"Session must not specified for snowpark dataset."
|
326
317
|
),
|
327
318
|
)
|
328
|
-
|
329
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
330
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
319
|
+
|
331
320
|
|
332
321
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
333
322
|
@telemetry.send_api_usage_telemetry(
|
@@ -375,7 +364,8 @@ class LGBMRegressor(BaseTransformer):
|
|
375
364
|
|
376
365
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
377
366
|
|
378
|
-
self.
|
367
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
368
|
+
self._deps = self._get_dependencies()
|
379
369
|
assert isinstance(
|
380
370
|
dataset._session, Session
|
381
371
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -458,10 +448,8 @@ class LGBMRegressor(BaseTransformer):
|
|
458
448
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
459
449
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
460
450
|
|
461
|
-
self.
|
462
|
-
|
463
|
-
inference_method=inference_method,
|
464
|
-
)
|
451
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
452
|
+
self._deps = self._get_dependencies()
|
465
453
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
466
454
|
|
467
455
|
transform_kwargs = dict(
|
@@ -528,16 +516,40 @@ class LGBMRegressor(BaseTransformer):
|
|
528
516
|
self._is_fitted = True
|
529
517
|
return output_result
|
530
518
|
|
519
|
+
|
520
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
521
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
522
|
+
""" Method not supported for this class.
|
531
523
|
|
532
|
-
|
533
|
-
|
534
|
-
|
524
|
+
|
525
|
+
Raises:
|
526
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
527
|
+
|
528
|
+
Args:
|
529
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
530
|
+
Snowpark or Pandas DataFrame.
|
531
|
+
output_cols_prefix: Prefix for the response columns
|
535
532
|
Returns:
|
536
533
|
Transformed dataset.
|
537
534
|
"""
|
538
|
-
self.
|
539
|
-
|
540
|
-
|
535
|
+
self._infer_input_output_cols(dataset)
|
536
|
+
super()._check_dataset_type(dataset)
|
537
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
538
|
+
estimator=self._sklearn_object,
|
539
|
+
dataset=dataset,
|
540
|
+
input_cols=self.input_cols,
|
541
|
+
label_cols=self.label_cols,
|
542
|
+
sample_weight_col=self.sample_weight_col,
|
543
|
+
autogenerated=self._autogenerated,
|
544
|
+
subproject=_SUBPROJECT,
|
545
|
+
)
|
546
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
547
|
+
drop_input_cols=self._drop_input_cols,
|
548
|
+
expected_output_cols_list=self.output_cols,
|
549
|
+
)
|
550
|
+
self._sklearn_object = fitted_estimator
|
551
|
+
self._is_fitted = True
|
552
|
+
return output_result
|
541
553
|
|
542
554
|
|
543
555
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -628,10 +640,8 @@ class LGBMRegressor(BaseTransformer):
|
|
628
640
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
629
641
|
|
630
642
|
if isinstance(dataset, DataFrame):
|
631
|
-
self.
|
632
|
-
|
633
|
-
inference_method=inference_method,
|
634
|
-
)
|
643
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
644
|
+
self._deps = self._get_dependencies()
|
635
645
|
assert isinstance(
|
636
646
|
dataset._session, Session
|
637
647
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -696,10 +706,8 @@ class LGBMRegressor(BaseTransformer):
|
|
696
706
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
697
707
|
|
698
708
|
if isinstance(dataset, DataFrame):
|
699
|
-
self.
|
700
|
-
|
701
|
-
inference_method=inference_method,
|
702
|
-
)
|
709
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
710
|
+
self._deps = self._get_dependencies()
|
703
711
|
assert isinstance(
|
704
712
|
dataset._session, Session
|
705
713
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -761,10 +769,8 @@ class LGBMRegressor(BaseTransformer):
|
|
761
769
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
762
770
|
|
763
771
|
if isinstance(dataset, DataFrame):
|
764
|
-
self.
|
765
|
-
|
766
|
-
inference_method=inference_method,
|
767
|
-
)
|
772
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
773
|
+
self._deps = self._get_dependencies()
|
768
774
|
assert isinstance(
|
769
775
|
dataset._session, Session
|
770
776
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -830,10 +836,8 @@ class LGBMRegressor(BaseTransformer):
|
|
830
836
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
831
837
|
|
832
838
|
if isinstance(dataset, DataFrame):
|
833
|
-
self.
|
834
|
-
|
835
|
-
inference_method=inference_method,
|
836
|
-
)
|
839
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
840
|
+
self._deps = self._get_dependencies()
|
837
841
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
838
842
|
transform_kwargs = dict(
|
839
843
|
session=dataset._session,
|
@@ -897,17 +901,15 @@ class LGBMRegressor(BaseTransformer):
|
|
897
901
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
898
902
|
|
899
903
|
if isinstance(dataset, DataFrame):
|
900
|
-
self.
|
901
|
-
|
902
|
-
inference_method="score",
|
903
|
-
)
|
904
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
905
|
+
self._deps = self._get_dependencies()
|
904
906
|
selected_cols = self._get_active_columns()
|
905
907
|
if len(selected_cols) > 0:
|
906
908
|
dataset = dataset.select(selected_cols)
|
907
909
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
908
910
|
transform_kwargs = dict(
|
909
911
|
session=dataset._session,
|
910
|
-
dependencies=
|
912
|
+
dependencies=self._deps,
|
911
913
|
score_sproc_imports=['lightgbm', 'sklearn'],
|
912
914
|
)
|
913
915
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -972,11 +974,8 @@ class LGBMRegressor(BaseTransformer):
|
|
972
974
|
|
973
975
|
if isinstance(dataset, DataFrame):
|
974
976
|
|
975
|
-
self.
|
976
|
-
|
977
|
-
inference_method=inference_method,
|
978
|
-
|
979
|
-
)
|
977
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
978
|
+
self._deps = self._get_dependencies()
|
980
979
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
981
980
|
transform_kwargs = dict(
|
982
981
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class ARDRegression(BaseTransformer):
|
70
64
|
r"""Bayesian ARD regression
|
71
65
|
For more details on this class, see [sklearn.linear_model.ARDRegression]
|
@@ -319,20 +313,17 @@ class ARDRegression(BaseTransformer):
|
|
319
313
|
self,
|
320
314
|
dataset: DataFrame,
|
321
315
|
inference_method: str,
|
322
|
-
) ->
|
323
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
324
|
-
return the available package that exists in the snowflake anaconda channel
|
316
|
+
) -> None:
|
317
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
325
318
|
|
326
319
|
Args:
|
327
320
|
dataset: snowpark dataframe
|
328
321
|
inference_method: the inference method such as predict, score...
|
329
|
-
|
322
|
+
|
330
323
|
Raises:
|
331
324
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
332
325
|
SnowflakeMLException: If the session is None, raise error
|
333
326
|
|
334
|
-
Returns:
|
335
|
-
A list of available package that exists in the snowflake anaconda channel
|
336
327
|
"""
|
337
328
|
if not self._is_fitted:
|
338
329
|
raise exceptions.SnowflakeMLException(
|
@@ -350,9 +341,7 @@ class ARDRegression(BaseTransformer):
|
|
350
341
|
"Session must not specified for snowpark dataset."
|
351
342
|
),
|
352
343
|
)
|
353
|
-
|
354
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
355
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
344
|
+
|
356
345
|
|
357
346
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
358
347
|
@telemetry.send_api_usage_telemetry(
|
@@ -400,7 +389,8 @@ class ARDRegression(BaseTransformer):
|
|
400
389
|
|
401
390
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
402
391
|
|
403
|
-
self.
|
392
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
393
|
+
self._deps = self._get_dependencies()
|
404
394
|
assert isinstance(
|
405
395
|
dataset._session, Session
|
406
396
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -483,10 +473,8 @@ class ARDRegression(BaseTransformer):
|
|
483
473
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
484
474
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
485
475
|
|
486
|
-
self.
|
487
|
-
|
488
|
-
inference_method=inference_method,
|
489
|
-
)
|
476
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
477
|
+
self._deps = self._get_dependencies()
|
490
478
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
491
479
|
|
492
480
|
transform_kwargs = dict(
|
@@ -553,16 +541,40 @@ class ARDRegression(BaseTransformer):
|
|
553
541
|
self._is_fitted = True
|
554
542
|
return output_result
|
555
543
|
|
544
|
+
|
545
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
546
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
547
|
+
""" Method not supported for this class.
|
556
548
|
|
557
|
-
|
558
|
-
|
559
|
-
|
549
|
+
|
550
|
+
Raises:
|
551
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
552
|
+
|
553
|
+
Args:
|
554
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
555
|
+
Snowpark or Pandas DataFrame.
|
556
|
+
output_cols_prefix: Prefix for the response columns
|
560
557
|
Returns:
|
561
558
|
Transformed dataset.
|
562
559
|
"""
|
563
|
-
self.
|
564
|
-
|
565
|
-
|
560
|
+
self._infer_input_output_cols(dataset)
|
561
|
+
super()._check_dataset_type(dataset)
|
562
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
563
|
+
estimator=self._sklearn_object,
|
564
|
+
dataset=dataset,
|
565
|
+
input_cols=self.input_cols,
|
566
|
+
label_cols=self.label_cols,
|
567
|
+
sample_weight_col=self.sample_weight_col,
|
568
|
+
autogenerated=self._autogenerated,
|
569
|
+
subproject=_SUBPROJECT,
|
570
|
+
)
|
571
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
572
|
+
drop_input_cols=self._drop_input_cols,
|
573
|
+
expected_output_cols_list=self.output_cols,
|
574
|
+
)
|
575
|
+
self._sklearn_object = fitted_estimator
|
576
|
+
self._is_fitted = True
|
577
|
+
return output_result
|
566
578
|
|
567
579
|
|
568
580
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -653,10 +665,8 @@ class ARDRegression(BaseTransformer):
|
|
653
665
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
654
666
|
|
655
667
|
if isinstance(dataset, DataFrame):
|
656
|
-
self.
|
657
|
-
|
658
|
-
inference_method=inference_method,
|
659
|
-
)
|
668
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
669
|
+
self._deps = self._get_dependencies()
|
660
670
|
assert isinstance(
|
661
671
|
dataset._session, Session
|
662
672
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -721,10 +731,8 @@ class ARDRegression(BaseTransformer):
|
|
721
731
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
722
732
|
|
723
733
|
if isinstance(dataset, DataFrame):
|
724
|
-
self.
|
725
|
-
|
726
|
-
inference_method=inference_method,
|
727
|
-
)
|
734
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
735
|
+
self._deps = self._get_dependencies()
|
728
736
|
assert isinstance(
|
729
737
|
dataset._session, Session
|
730
738
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -786,10 +794,8 @@ class ARDRegression(BaseTransformer):
|
|
786
794
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
787
795
|
|
788
796
|
if isinstance(dataset, DataFrame):
|
789
|
-
self.
|
790
|
-
|
791
|
-
inference_method=inference_method,
|
792
|
-
)
|
797
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
798
|
+
self._deps = self._get_dependencies()
|
793
799
|
assert isinstance(
|
794
800
|
dataset._session, Session
|
795
801
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -855,10 +861,8 @@ class ARDRegression(BaseTransformer):
|
|
855
861
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
856
862
|
|
857
863
|
if isinstance(dataset, DataFrame):
|
858
|
-
self.
|
859
|
-
|
860
|
-
inference_method=inference_method,
|
861
|
-
)
|
864
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
865
|
+
self._deps = self._get_dependencies()
|
862
866
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
863
867
|
transform_kwargs = dict(
|
864
868
|
session=dataset._session,
|
@@ -922,17 +926,15 @@ class ARDRegression(BaseTransformer):
|
|
922
926
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
923
927
|
|
924
928
|
if isinstance(dataset, DataFrame):
|
925
|
-
self.
|
926
|
-
|
927
|
-
inference_method="score",
|
928
|
-
)
|
929
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
930
|
+
self._deps = self._get_dependencies()
|
929
931
|
selected_cols = self._get_active_columns()
|
930
932
|
if len(selected_cols) > 0:
|
931
933
|
dataset = dataset.select(selected_cols)
|
932
934
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
933
935
|
transform_kwargs = dict(
|
934
936
|
session=dataset._session,
|
935
|
-
dependencies=
|
937
|
+
dependencies=self._deps,
|
936
938
|
score_sproc_imports=['sklearn'],
|
937
939
|
)
|
938
940
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -997,11 +999,8 @@ class ARDRegression(BaseTransformer):
|
|
997
999
|
|
998
1000
|
if isinstance(dataset, DataFrame):
|
999
1001
|
|
1000
|
-
self.
|
1001
|
-
|
1002
|
-
inference_method=inference_method,
|
1003
|
-
|
1004
|
-
)
|
1002
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1003
|
+
self._deps = self._get_dependencies()
|
1005
1004
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1006
1005
|
transform_kwargs = dict(
|
1007
1006
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class BayesianRidge(BaseTransformer):
|
70
64
|
r"""Bayesian ridge regression
|
71
65
|
For more details on this class, see [sklearn.linear_model.BayesianRidge]
|
@@ -330,20 +324,17 @@ class BayesianRidge(BaseTransformer):
|
|
330
324
|
self,
|
331
325
|
dataset: DataFrame,
|
332
326
|
inference_method: str,
|
333
|
-
) ->
|
334
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
335
|
-
return the available package that exists in the snowflake anaconda channel
|
327
|
+
) -> None:
|
328
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
336
329
|
|
337
330
|
Args:
|
338
331
|
dataset: snowpark dataframe
|
339
332
|
inference_method: the inference method such as predict, score...
|
340
|
-
|
333
|
+
|
341
334
|
Raises:
|
342
335
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
343
336
|
SnowflakeMLException: If the session is None, raise error
|
344
337
|
|
345
|
-
Returns:
|
346
|
-
A list of available package that exists in the snowflake anaconda channel
|
347
338
|
"""
|
348
339
|
if not self._is_fitted:
|
349
340
|
raise exceptions.SnowflakeMLException(
|
@@ -361,9 +352,7 @@ class BayesianRidge(BaseTransformer):
|
|
361
352
|
"Session must not specified for snowpark dataset."
|
362
353
|
),
|
363
354
|
)
|
364
|
-
|
365
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
366
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
355
|
+
|
367
356
|
|
368
357
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
369
358
|
@telemetry.send_api_usage_telemetry(
|
@@ -411,7 +400,8 @@ class BayesianRidge(BaseTransformer):
|
|
411
400
|
|
412
401
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
413
402
|
|
414
|
-
self.
|
403
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
404
|
+
self._deps = self._get_dependencies()
|
415
405
|
assert isinstance(
|
416
406
|
dataset._session, Session
|
417
407
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -494,10 +484,8 @@ class BayesianRidge(BaseTransformer):
|
|
494
484
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
495
485
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
496
486
|
|
497
|
-
self.
|
498
|
-
|
499
|
-
inference_method=inference_method,
|
500
|
-
)
|
487
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
488
|
+
self._deps = self._get_dependencies()
|
501
489
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
502
490
|
|
503
491
|
transform_kwargs = dict(
|
@@ -564,16 +552,40 @@ class BayesianRidge(BaseTransformer):
|
|
564
552
|
self._is_fitted = True
|
565
553
|
return output_result
|
566
554
|
|
555
|
+
|
556
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
557
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
558
|
+
""" Method not supported for this class.
|
567
559
|
|
568
|
-
|
569
|
-
|
570
|
-
|
560
|
+
|
561
|
+
Raises:
|
562
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
563
|
+
|
564
|
+
Args:
|
565
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
566
|
+
Snowpark or Pandas DataFrame.
|
567
|
+
output_cols_prefix: Prefix for the response columns
|
571
568
|
Returns:
|
572
569
|
Transformed dataset.
|
573
570
|
"""
|
574
|
-
self.
|
575
|
-
|
576
|
-
|
571
|
+
self._infer_input_output_cols(dataset)
|
572
|
+
super()._check_dataset_type(dataset)
|
573
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
574
|
+
estimator=self._sklearn_object,
|
575
|
+
dataset=dataset,
|
576
|
+
input_cols=self.input_cols,
|
577
|
+
label_cols=self.label_cols,
|
578
|
+
sample_weight_col=self.sample_weight_col,
|
579
|
+
autogenerated=self._autogenerated,
|
580
|
+
subproject=_SUBPROJECT,
|
581
|
+
)
|
582
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
583
|
+
drop_input_cols=self._drop_input_cols,
|
584
|
+
expected_output_cols_list=self.output_cols,
|
585
|
+
)
|
586
|
+
self._sklearn_object = fitted_estimator
|
587
|
+
self._is_fitted = True
|
588
|
+
return output_result
|
577
589
|
|
578
590
|
|
579
591
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -664,10 +676,8 @@ class BayesianRidge(BaseTransformer):
|
|
664
676
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
665
677
|
|
666
678
|
if isinstance(dataset, DataFrame):
|
667
|
-
self.
|
668
|
-
|
669
|
-
inference_method=inference_method,
|
670
|
-
)
|
679
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
680
|
+
self._deps = self._get_dependencies()
|
671
681
|
assert isinstance(
|
672
682
|
dataset._session, Session
|
673
683
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -732,10 +742,8 @@ class BayesianRidge(BaseTransformer):
|
|
732
742
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
733
743
|
|
734
744
|
if isinstance(dataset, DataFrame):
|
735
|
-
self.
|
736
|
-
|
737
|
-
inference_method=inference_method,
|
738
|
-
)
|
745
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
746
|
+
self._deps = self._get_dependencies()
|
739
747
|
assert isinstance(
|
740
748
|
dataset._session, Session
|
741
749
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -797,10 +805,8 @@ class BayesianRidge(BaseTransformer):
|
|
797
805
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
798
806
|
|
799
807
|
if isinstance(dataset, DataFrame):
|
800
|
-
self.
|
801
|
-
|
802
|
-
inference_method=inference_method,
|
803
|
-
)
|
808
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
809
|
+
self._deps = self._get_dependencies()
|
804
810
|
assert isinstance(
|
805
811
|
dataset._session, Session
|
806
812
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -866,10 +872,8 @@ class BayesianRidge(BaseTransformer):
|
|
866
872
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
867
873
|
|
868
874
|
if isinstance(dataset, DataFrame):
|
869
|
-
self.
|
870
|
-
|
871
|
-
inference_method=inference_method,
|
872
|
-
)
|
875
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
876
|
+
self._deps = self._get_dependencies()
|
873
877
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
874
878
|
transform_kwargs = dict(
|
875
879
|
session=dataset._session,
|
@@ -933,17 +937,15 @@ class BayesianRidge(BaseTransformer):
|
|
933
937
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
934
938
|
|
935
939
|
if isinstance(dataset, DataFrame):
|
936
|
-
self.
|
937
|
-
|
938
|
-
inference_method="score",
|
939
|
-
)
|
940
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
941
|
+
self._deps = self._get_dependencies()
|
940
942
|
selected_cols = self._get_active_columns()
|
941
943
|
if len(selected_cols) > 0:
|
942
944
|
dataset = dataset.select(selected_cols)
|
943
945
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
944
946
|
transform_kwargs = dict(
|
945
947
|
session=dataset._session,
|
946
|
-
dependencies=
|
948
|
+
dependencies=self._deps,
|
947
949
|
score_sproc_imports=['sklearn'],
|
948
950
|
)
|
949
951
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1008,11 +1010,8 @@ class BayesianRidge(BaseTransformer):
|
|
1008
1010
|
|
1009
1011
|
if isinstance(dataset, DataFrame):
|
1010
1012
|
|
1011
|
-
self.
|
1012
|
-
|
1013
|
-
inference_method=inference_method,
|
1014
|
-
|
1015
|
-
)
|
1013
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1014
|
+
self._deps = self._get_dependencies()
|
1016
1015
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1017
1016
|
transform_kwargs = dict(
|
1018
1017
|
session = dataset._session,
|