snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +72 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
- snowflake/ml/_internal/telemetry.py +1 -0
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/sql_identifier.py +14 -1
- snowflake/ml/dataset/__init__.py +11 -0
- snowflake/ml/dataset/dataset.py +455 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +199 -0
- snowflake/ml/feature_store/__init__.py +6 -0
- snowflake/ml/feature_store/access_manager.py +279 -0
- snowflake/ml/feature_store/feature_store.py +544 -358
- snowflake/ml/feature_store/feature_view.py +55 -16
- snowflake/ml/fileset/embedded_stage_fs.py +149 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +160 -0
- snowflake/ml/fileset/stage_fs.py +25 -10
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +65 -31
- snowflake/ml/model/_client/model/model_version_impl.py +159 -2
- snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
- snowflake/ml/model/_client/ops/model_ops.py +268 -83
- snowflake/ml/model/_client/sql/_base.py +34 -0
- snowflake/ml/model/_client/sql/model.py +42 -47
- snowflake/ml/model/_client/sql/model_version.py +164 -39
- snowflake/ml/model/_client/sql/stage.py +6 -32
- snowflake/ml/model/_client/sql/tag.py +32 -56
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +64 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +538 -36
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/_manager/model_manager.py +36 -7
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class LogisticRegression(BaseTransformer):
|
70
64
|
r"""Logistic Regression (aka logit, MaxEnt) classifier
|
71
65
|
For more details on this class, see [sklearn.linear_model.LogisticRegression]
|
@@ -394,20 +388,17 @@ class LogisticRegression(BaseTransformer):
|
|
394
388
|
self,
|
395
389
|
dataset: DataFrame,
|
396
390
|
inference_method: str,
|
397
|
-
) ->
|
398
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
399
|
-
return the available package that exists in the snowflake anaconda channel
|
391
|
+
) -> None:
|
392
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
400
393
|
|
401
394
|
Args:
|
402
395
|
dataset: snowpark dataframe
|
403
396
|
inference_method: the inference method such as predict, score...
|
404
|
-
|
397
|
+
|
405
398
|
Raises:
|
406
399
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
407
400
|
SnowflakeMLException: If the session is None, raise error
|
408
401
|
|
409
|
-
Returns:
|
410
|
-
A list of available package that exists in the snowflake anaconda channel
|
411
402
|
"""
|
412
403
|
if not self._is_fitted:
|
413
404
|
raise exceptions.SnowflakeMLException(
|
@@ -425,9 +416,7 @@ class LogisticRegression(BaseTransformer):
|
|
425
416
|
"Session must not specified for snowpark dataset."
|
426
417
|
),
|
427
418
|
)
|
428
|
-
|
429
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
430
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
419
|
+
|
431
420
|
|
432
421
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
433
422
|
@telemetry.send_api_usage_telemetry(
|
@@ -475,7 +464,8 @@ class LogisticRegression(BaseTransformer):
|
|
475
464
|
|
476
465
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
477
466
|
|
478
|
-
self.
|
467
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
468
|
+
self._deps = self._get_dependencies()
|
479
469
|
assert isinstance(
|
480
470
|
dataset._session, Session
|
481
471
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -558,10 +548,8 @@ class LogisticRegression(BaseTransformer):
|
|
558
548
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
559
549
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
560
550
|
|
561
|
-
self.
|
562
|
-
|
563
|
-
inference_method=inference_method,
|
564
|
-
)
|
551
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
552
|
+
self._deps = self._get_dependencies()
|
565
553
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
566
554
|
|
567
555
|
transform_kwargs = dict(
|
@@ -628,16 +616,40 @@ class LogisticRegression(BaseTransformer):
|
|
628
616
|
self._is_fitted = True
|
629
617
|
return output_result
|
630
618
|
|
619
|
+
|
620
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
621
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
622
|
+
""" Method not supported for this class.
|
631
623
|
|
632
|
-
|
633
|
-
|
634
|
-
|
624
|
+
|
625
|
+
Raises:
|
626
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
627
|
+
|
628
|
+
Args:
|
629
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
630
|
+
Snowpark or Pandas DataFrame.
|
631
|
+
output_cols_prefix: Prefix for the response columns
|
635
632
|
Returns:
|
636
633
|
Transformed dataset.
|
637
634
|
"""
|
638
|
-
self.
|
639
|
-
|
640
|
-
|
635
|
+
self._infer_input_output_cols(dataset)
|
636
|
+
super()._check_dataset_type(dataset)
|
637
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
638
|
+
estimator=self._sklearn_object,
|
639
|
+
dataset=dataset,
|
640
|
+
input_cols=self.input_cols,
|
641
|
+
label_cols=self.label_cols,
|
642
|
+
sample_weight_col=self.sample_weight_col,
|
643
|
+
autogenerated=self._autogenerated,
|
644
|
+
subproject=_SUBPROJECT,
|
645
|
+
)
|
646
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
647
|
+
drop_input_cols=self._drop_input_cols,
|
648
|
+
expected_output_cols_list=self.output_cols,
|
649
|
+
)
|
650
|
+
self._sklearn_object = fitted_estimator
|
651
|
+
self._is_fitted = True
|
652
|
+
return output_result
|
641
653
|
|
642
654
|
|
643
655
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -730,10 +742,8 @@ class LogisticRegression(BaseTransformer):
|
|
730
742
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
731
743
|
|
732
744
|
if isinstance(dataset, DataFrame):
|
733
|
-
self.
|
734
|
-
|
735
|
-
inference_method=inference_method,
|
736
|
-
)
|
745
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
746
|
+
self._deps = self._get_dependencies()
|
737
747
|
assert isinstance(
|
738
748
|
dataset._session, Session
|
739
749
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -800,10 +810,8 @@ class LogisticRegression(BaseTransformer):
|
|
800
810
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
801
811
|
|
802
812
|
if isinstance(dataset, DataFrame):
|
803
|
-
self.
|
804
|
-
|
805
|
-
inference_method=inference_method,
|
806
|
-
)
|
813
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
814
|
+
self._deps = self._get_dependencies()
|
807
815
|
assert isinstance(
|
808
816
|
dataset._session, Session
|
809
817
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -867,10 +875,8 @@ class LogisticRegression(BaseTransformer):
|
|
867
875
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
868
876
|
|
869
877
|
if isinstance(dataset, DataFrame):
|
870
|
-
self.
|
871
|
-
|
872
|
-
inference_method=inference_method,
|
873
|
-
)
|
878
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
879
|
+
self._deps = self._get_dependencies()
|
874
880
|
assert isinstance(
|
875
881
|
dataset._session, Session
|
876
882
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -936,10 +942,8 @@ class LogisticRegression(BaseTransformer):
|
|
936
942
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
937
943
|
|
938
944
|
if isinstance(dataset, DataFrame):
|
939
|
-
self.
|
940
|
-
|
941
|
-
inference_method=inference_method,
|
942
|
-
)
|
945
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
946
|
+
self._deps = self._get_dependencies()
|
943
947
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
944
948
|
transform_kwargs = dict(
|
945
949
|
session=dataset._session,
|
@@ -1003,17 +1007,15 @@ class LogisticRegression(BaseTransformer):
|
|
1003
1007
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
1004
1008
|
|
1005
1009
|
if isinstance(dataset, DataFrame):
|
1006
|
-
self.
|
1007
|
-
|
1008
|
-
inference_method="score",
|
1009
|
-
)
|
1010
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
1011
|
+
self._deps = self._get_dependencies()
|
1010
1012
|
selected_cols = self._get_active_columns()
|
1011
1013
|
if len(selected_cols) > 0:
|
1012
1014
|
dataset = dataset.select(selected_cols)
|
1013
1015
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
1014
1016
|
transform_kwargs = dict(
|
1015
1017
|
session=dataset._session,
|
1016
|
-
dependencies=
|
1018
|
+
dependencies=self._deps,
|
1017
1019
|
score_sproc_imports=['sklearn'],
|
1018
1020
|
)
|
1019
1021
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1078,11 +1080,8 @@ class LogisticRegression(BaseTransformer):
|
|
1078
1080
|
|
1079
1081
|
if isinstance(dataset, DataFrame):
|
1080
1082
|
|
1081
|
-
self.
|
1082
|
-
|
1083
|
-
inference_method=inference_method,
|
1084
|
-
|
1085
|
-
)
|
1083
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1084
|
+
self._deps = self._get_dependencies()
|
1086
1085
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1087
1086
|
transform_kwargs = dict(
|
1088
1087
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class LogisticRegressionCV(BaseTransformer):
|
70
64
|
r"""Logistic Regression CV (aka logit, MaxEnt) classifier
|
71
65
|
For more details on this class, see [sklearn.linear_model.LogisticRegressionCV]
|
@@ -415,20 +409,17 @@ class LogisticRegressionCV(BaseTransformer):
|
|
415
409
|
self,
|
416
410
|
dataset: DataFrame,
|
417
411
|
inference_method: str,
|
418
|
-
) ->
|
419
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
420
|
-
return the available package that exists in the snowflake anaconda channel
|
412
|
+
) -> None:
|
413
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
421
414
|
|
422
415
|
Args:
|
423
416
|
dataset: snowpark dataframe
|
424
417
|
inference_method: the inference method such as predict, score...
|
425
|
-
|
418
|
+
|
426
419
|
Raises:
|
427
420
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
428
421
|
SnowflakeMLException: If the session is None, raise error
|
429
422
|
|
430
|
-
Returns:
|
431
|
-
A list of available package that exists in the snowflake anaconda channel
|
432
423
|
"""
|
433
424
|
if not self._is_fitted:
|
434
425
|
raise exceptions.SnowflakeMLException(
|
@@ -446,9 +437,7 @@ class LogisticRegressionCV(BaseTransformer):
|
|
446
437
|
"Session must not specified for snowpark dataset."
|
447
438
|
),
|
448
439
|
)
|
449
|
-
|
450
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
451
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
440
|
+
|
452
441
|
|
453
442
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
454
443
|
@telemetry.send_api_usage_telemetry(
|
@@ -496,7 +485,8 @@ class LogisticRegressionCV(BaseTransformer):
|
|
496
485
|
|
497
486
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
498
487
|
|
499
|
-
self.
|
488
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
489
|
+
self._deps = self._get_dependencies()
|
500
490
|
assert isinstance(
|
501
491
|
dataset._session, Session
|
502
492
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -579,10 +569,8 @@ class LogisticRegressionCV(BaseTransformer):
|
|
579
569
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
580
570
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
581
571
|
|
582
|
-
self.
|
583
|
-
|
584
|
-
inference_method=inference_method,
|
585
|
-
)
|
572
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
573
|
+
self._deps = self._get_dependencies()
|
586
574
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
587
575
|
|
588
576
|
transform_kwargs = dict(
|
@@ -649,16 +637,40 @@ class LogisticRegressionCV(BaseTransformer):
|
|
649
637
|
self._is_fitted = True
|
650
638
|
return output_result
|
651
639
|
|
640
|
+
|
641
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
642
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
643
|
+
""" Method not supported for this class.
|
652
644
|
|
653
|
-
|
654
|
-
|
655
|
-
|
645
|
+
|
646
|
+
Raises:
|
647
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
648
|
+
|
649
|
+
Args:
|
650
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
651
|
+
Snowpark or Pandas DataFrame.
|
652
|
+
output_cols_prefix: Prefix for the response columns
|
656
653
|
Returns:
|
657
654
|
Transformed dataset.
|
658
655
|
"""
|
659
|
-
self.
|
660
|
-
|
661
|
-
|
656
|
+
self._infer_input_output_cols(dataset)
|
657
|
+
super()._check_dataset_type(dataset)
|
658
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
659
|
+
estimator=self._sklearn_object,
|
660
|
+
dataset=dataset,
|
661
|
+
input_cols=self.input_cols,
|
662
|
+
label_cols=self.label_cols,
|
663
|
+
sample_weight_col=self.sample_weight_col,
|
664
|
+
autogenerated=self._autogenerated,
|
665
|
+
subproject=_SUBPROJECT,
|
666
|
+
)
|
667
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
668
|
+
drop_input_cols=self._drop_input_cols,
|
669
|
+
expected_output_cols_list=self.output_cols,
|
670
|
+
)
|
671
|
+
self._sklearn_object = fitted_estimator
|
672
|
+
self._is_fitted = True
|
673
|
+
return output_result
|
662
674
|
|
663
675
|
|
664
676
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -751,10 +763,8 @@ class LogisticRegressionCV(BaseTransformer):
|
|
751
763
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
752
764
|
|
753
765
|
if isinstance(dataset, DataFrame):
|
754
|
-
self.
|
755
|
-
|
756
|
-
inference_method=inference_method,
|
757
|
-
)
|
766
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
767
|
+
self._deps = self._get_dependencies()
|
758
768
|
assert isinstance(
|
759
769
|
dataset._session, Session
|
760
770
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -821,10 +831,8 @@ class LogisticRegressionCV(BaseTransformer):
|
|
821
831
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
822
832
|
|
823
833
|
if isinstance(dataset, DataFrame):
|
824
|
-
self.
|
825
|
-
|
826
|
-
inference_method=inference_method,
|
827
|
-
)
|
834
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
835
|
+
self._deps = self._get_dependencies()
|
828
836
|
assert isinstance(
|
829
837
|
dataset._session, Session
|
830
838
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -888,10 +896,8 @@ class LogisticRegressionCV(BaseTransformer):
|
|
888
896
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
889
897
|
|
890
898
|
if isinstance(dataset, DataFrame):
|
891
|
-
self.
|
892
|
-
|
893
|
-
inference_method=inference_method,
|
894
|
-
)
|
899
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
900
|
+
self._deps = self._get_dependencies()
|
895
901
|
assert isinstance(
|
896
902
|
dataset._session, Session
|
897
903
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -957,10 +963,8 @@ class LogisticRegressionCV(BaseTransformer):
|
|
957
963
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
958
964
|
|
959
965
|
if isinstance(dataset, DataFrame):
|
960
|
-
self.
|
961
|
-
|
962
|
-
inference_method=inference_method,
|
963
|
-
)
|
966
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
967
|
+
self._deps = self._get_dependencies()
|
964
968
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
965
969
|
transform_kwargs = dict(
|
966
970
|
session=dataset._session,
|
@@ -1024,17 +1028,15 @@ class LogisticRegressionCV(BaseTransformer):
|
|
1024
1028
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
1025
1029
|
|
1026
1030
|
if isinstance(dataset, DataFrame):
|
1027
|
-
self.
|
1028
|
-
|
1029
|
-
inference_method="score",
|
1030
|
-
)
|
1031
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
1032
|
+
self._deps = self._get_dependencies()
|
1031
1033
|
selected_cols = self._get_active_columns()
|
1032
1034
|
if len(selected_cols) > 0:
|
1033
1035
|
dataset = dataset.select(selected_cols)
|
1034
1036
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
1035
1037
|
transform_kwargs = dict(
|
1036
1038
|
session=dataset._session,
|
1037
|
-
dependencies=
|
1039
|
+
dependencies=self._deps,
|
1038
1040
|
score_sproc_imports=['sklearn'],
|
1039
1041
|
)
|
1040
1042
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1099,11 +1101,8 @@ class LogisticRegressionCV(BaseTransformer):
|
|
1099
1101
|
|
1100
1102
|
if isinstance(dataset, DataFrame):
|
1101
1103
|
|
1102
|
-
self.
|
1103
|
-
|
1104
|
-
inference_method=inference_method,
|
1105
|
-
|
1106
|
-
)
|
1104
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1105
|
+
self._deps = self._get_dependencies()
|
1107
1106
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1108
1107
|
transform_kwargs = dict(
|
1109
1108
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class MultiTaskElasticNet(BaseTransformer):
|
70
64
|
r"""Multi-task ElasticNet model trained with L1/L2 mixed-norm as regularizer
|
71
65
|
For more details on this class, see [sklearn.linear_model.MultiTaskElasticNet]
|
@@ -313,20 +307,17 @@ class MultiTaskElasticNet(BaseTransformer):
|
|
313
307
|
self,
|
314
308
|
dataset: DataFrame,
|
315
309
|
inference_method: str,
|
316
|
-
) ->
|
317
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
318
|
-
return the available package that exists in the snowflake anaconda channel
|
310
|
+
) -> None:
|
311
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
319
312
|
|
320
313
|
Args:
|
321
314
|
dataset: snowpark dataframe
|
322
315
|
inference_method: the inference method such as predict, score...
|
323
|
-
|
316
|
+
|
324
317
|
Raises:
|
325
318
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
326
319
|
SnowflakeMLException: If the session is None, raise error
|
327
320
|
|
328
|
-
Returns:
|
329
|
-
A list of available package that exists in the snowflake anaconda channel
|
330
321
|
"""
|
331
322
|
if not self._is_fitted:
|
332
323
|
raise exceptions.SnowflakeMLException(
|
@@ -344,9 +335,7 @@ class MultiTaskElasticNet(BaseTransformer):
|
|
344
335
|
"Session must not specified for snowpark dataset."
|
345
336
|
),
|
346
337
|
)
|
347
|
-
|
348
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
349
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
338
|
+
|
350
339
|
|
351
340
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
352
341
|
@telemetry.send_api_usage_telemetry(
|
@@ -394,7 +383,8 @@ class MultiTaskElasticNet(BaseTransformer):
|
|
394
383
|
|
395
384
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
396
385
|
|
397
|
-
self.
|
386
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
387
|
+
self._deps = self._get_dependencies()
|
398
388
|
assert isinstance(
|
399
389
|
dataset._session, Session
|
400
390
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -477,10 +467,8 @@ class MultiTaskElasticNet(BaseTransformer):
|
|
477
467
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
478
468
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
479
469
|
|
480
|
-
self.
|
481
|
-
|
482
|
-
inference_method=inference_method,
|
483
|
-
)
|
470
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
471
|
+
self._deps = self._get_dependencies()
|
484
472
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
485
473
|
|
486
474
|
transform_kwargs = dict(
|
@@ -547,16 +535,40 @@ class MultiTaskElasticNet(BaseTransformer):
|
|
547
535
|
self._is_fitted = True
|
548
536
|
return output_result
|
549
537
|
|
538
|
+
|
539
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
540
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
541
|
+
""" Method not supported for this class.
|
550
542
|
|
551
|
-
|
552
|
-
|
553
|
-
|
543
|
+
|
544
|
+
Raises:
|
545
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
546
|
+
|
547
|
+
Args:
|
548
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
549
|
+
Snowpark or Pandas DataFrame.
|
550
|
+
output_cols_prefix: Prefix for the response columns
|
554
551
|
Returns:
|
555
552
|
Transformed dataset.
|
556
553
|
"""
|
557
|
-
self.
|
558
|
-
|
559
|
-
|
554
|
+
self._infer_input_output_cols(dataset)
|
555
|
+
super()._check_dataset_type(dataset)
|
556
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
557
|
+
estimator=self._sklearn_object,
|
558
|
+
dataset=dataset,
|
559
|
+
input_cols=self.input_cols,
|
560
|
+
label_cols=self.label_cols,
|
561
|
+
sample_weight_col=self.sample_weight_col,
|
562
|
+
autogenerated=self._autogenerated,
|
563
|
+
subproject=_SUBPROJECT,
|
564
|
+
)
|
565
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
566
|
+
drop_input_cols=self._drop_input_cols,
|
567
|
+
expected_output_cols_list=self.output_cols,
|
568
|
+
)
|
569
|
+
self._sklearn_object = fitted_estimator
|
570
|
+
self._is_fitted = True
|
571
|
+
return output_result
|
560
572
|
|
561
573
|
|
562
574
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -647,10 +659,8 @@ class MultiTaskElasticNet(BaseTransformer):
|
|
647
659
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
648
660
|
|
649
661
|
if isinstance(dataset, DataFrame):
|
650
|
-
self.
|
651
|
-
|
652
|
-
inference_method=inference_method,
|
653
|
-
)
|
662
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
663
|
+
self._deps = self._get_dependencies()
|
654
664
|
assert isinstance(
|
655
665
|
dataset._session, Session
|
656
666
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -715,10 +725,8 @@ class MultiTaskElasticNet(BaseTransformer):
|
|
715
725
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
716
726
|
|
717
727
|
if isinstance(dataset, DataFrame):
|
718
|
-
self.
|
719
|
-
|
720
|
-
inference_method=inference_method,
|
721
|
-
)
|
728
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
729
|
+
self._deps = self._get_dependencies()
|
722
730
|
assert isinstance(
|
723
731
|
dataset._session, Session
|
724
732
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -780,10 +788,8 @@ class MultiTaskElasticNet(BaseTransformer):
|
|
780
788
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
781
789
|
|
782
790
|
if isinstance(dataset, DataFrame):
|
783
|
-
self.
|
784
|
-
|
785
|
-
inference_method=inference_method,
|
786
|
-
)
|
791
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
792
|
+
self._deps = self._get_dependencies()
|
787
793
|
assert isinstance(
|
788
794
|
dataset._session, Session
|
789
795
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -849,10 +855,8 @@ class MultiTaskElasticNet(BaseTransformer):
|
|
849
855
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
850
856
|
|
851
857
|
if isinstance(dataset, DataFrame):
|
852
|
-
self.
|
853
|
-
|
854
|
-
inference_method=inference_method,
|
855
|
-
)
|
858
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
859
|
+
self._deps = self._get_dependencies()
|
856
860
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
857
861
|
transform_kwargs = dict(
|
858
862
|
session=dataset._session,
|
@@ -916,17 +920,15 @@ class MultiTaskElasticNet(BaseTransformer):
|
|
916
920
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
917
921
|
|
918
922
|
if isinstance(dataset, DataFrame):
|
919
|
-
self.
|
920
|
-
|
921
|
-
inference_method="score",
|
922
|
-
)
|
923
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
924
|
+
self._deps = self._get_dependencies()
|
923
925
|
selected_cols = self._get_active_columns()
|
924
926
|
if len(selected_cols) > 0:
|
925
927
|
dataset = dataset.select(selected_cols)
|
926
928
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
927
929
|
transform_kwargs = dict(
|
928
930
|
session=dataset._session,
|
929
|
-
dependencies=
|
931
|
+
dependencies=self._deps,
|
930
932
|
score_sproc_imports=['sklearn'],
|
931
933
|
)
|
932
934
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -991,11 +993,8 @@ class MultiTaskElasticNet(BaseTransformer):
|
|
991
993
|
|
992
994
|
if isinstance(dataset, DataFrame):
|
993
995
|
|
994
|
-
self.
|
995
|
-
|
996
|
-
inference_method=inference_method,
|
997
|
-
|
998
|
-
)
|
996
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
997
|
+
self._deps = self._get_dependencies()
|
999
998
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1000
999
|
transform_kwargs = dict(
|
1001
1000
|
session = dataset._session,
|