snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class MultiTaskElasticNetCV(BaseTransformer):
|
70
64
|
r"""Multi-task L1/L2 ElasticNet with built-in cross-validation
|
71
65
|
For more details on this class, see [sklearn.linear_model.MultiTaskElasticNetCV]
|
@@ -354,20 +348,17 @@ class MultiTaskElasticNetCV(BaseTransformer):
|
|
354
348
|
self,
|
355
349
|
dataset: DataFrame,
|
356
350
|
inference_method: str,
|
357
|
-
) ->
|
358
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
359
|
-
return the available package that exists in the snowflake anaconda channel
|
351
|
+
) -> None:
|
352
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
360
353
|
|
361
354
|
Args:
|
362
355
|
dataset: snowpark dataframe
|
363
356
|
inference_method: the inference method such as predict, score...
|
364
|
-
|
357
|
+
|
365
358
|
Raises:
|
366
359
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
367
360
|
SnowflakeMLException: If the session is None, raise error
|
368
361
|
|
369
|
-
Returns:
|
370
|
-
A list of available package that exists in the snowflake anaconda channel
|
371
362
|
"""
|
372
363
|
if not self._is_fitted:
|
373
364
|
raise exceptions.SnowflakeMLException(
|
@@ -385,9 +376,7 @@ class MultiTaskElasticNetCV(BaseTransformer):
|
|
385
376
|
"Session must not specified for snowpark dataset."
|
386
377
|
),
|
387
378
|
)
|
388
|
-
|
389
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
390
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
379
|
+
|
391
380
|
|
392
381
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
393
382
|
@telemetry.send_api_usage_telemetry(
|
@@ -435,7 +424,8 @@ class MultiTaskElasticNetCV(BaseTransformer):
|
|
435
424
|
|
436
425
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
437
426
|
|
438
|
-
self.
|
427
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
428
|
+
self._deps = self._get_dependencies()
|
439
429
|
assert isinstance(
|
440
430
|
dataset._session, Session
|
441
431
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -518,10 +508,8 @@ class MultiTaskElasticNetCV(BaseTransformer):
|
|
518
508
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
519
509
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
520
510
|
|
521
|
-
self.
|
522
|
-
|
523
|
-
inference_method=inference_method,
|
524
|
-
)
|
511
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
512
|
+
self._deps = self._get_dependencies()
|
525
513
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
526
514
|
|
527
515
|
transform_kwargs = dict(
|
@@ -588,16 +576,40 @@ class MultiTaskElasticNetCV(BaseTransformer):
|
|
588
576
|
self._is_fitted = True
|
589
577
|
return output_result
|
590
578
|
|
579
|
+
|
580
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
581
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
582
|
+
""" Method not supported for this class.
|
591
583
|
|
592
|
-
|
593
|
-
|
594
|
-
|
584
|
+
|
585
|
+
Raises:
|
586
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
587
|
+
|
588
|
+
Args:
|
589
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
590
|
+
Snowpark or Pandas DataFrame.
|
591
|
+
output_cols_prefix: Prefix for the response columns
|
595
592
|
Returns:
|
596
593
|
Transformed dataset.
|
597
594
|
"""
|
598
|
-
self.
|
599
|
-
|
600
|
-
|
595
|
+
self._infer_input_output_cols(dataset)
|
596
|
+
super()._check_dataset_type(dataset)
|
597
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
598
|
+
estimator=self._sklearn_object,
|
599
|
+
dataset=dataset,
|
600
|
+
input_cols=self.input_cols,
|
601
|
+
label_cols=self.label_cols,
|
602
|
+
sample_weight_col=self.sample_weight_col,
|
603
|
+
autogenerated=self._autogenerated,
|
604
|
+
subproject=_SUBPROJECT,
|
605
|
+
)
|
606
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
607
|
+
drop_input_cols=self._drop_input_cols,
|
608
|
+
expected_output_cols_list=self.output_cols,
|
609
|
+
)
|
610
|
+
self._sklearn_object = fitted_estimator
|
611
|
+
self._is_fitted = True
|
612
|
+
return output_result
|
601
613
|
|
602
614
|
|
603
615
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -688,10 +700,8 @@ class MultiTaskElasticNetCV(BaseTransformer):
|
|
688
700
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
689
701
|
|
690
702
|
if isinstance(dataset, DataFrame):
|
691
|
-
self.
|
692
|
-
|
693
|
-
inference_method=inference_method,
|
694
|
-
)
|
703
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
704
|
+
self._deps = self._get_dependencies()
|
695
705
|
assert isinstance(
|
696
706
|
dataset._session, Session
|
697
707
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -756,10 +766,8 @@ class MultiTaskElasticNetCV(BaseTransformer):
|
|
756
766
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
757
767
|
|
758
768
|
if isinstance(dataset, DataFrame):
|
759
|
-
self.
|
760
|
-
|
761
|
-
inference_method=inference_method,
|
762
|
-
)
|
769
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
770
|
+
self._deps = self._get_dependencies()
|
763
771
|
assert isinstance(
|
764
772
|
dataset._session, Session
|
765
773
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -821,10 +829,8 @@ class MultiTaskElasticNetCV(BaseTransformer):
|
|
821
829
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
822
830
|
|
823
831
|
if isinstance(dataset, DataFrame):
|
824
|
-
self.
|
825
|
-
|
826
|
-
inference_method=inference_method,
|
827
|
-
)
|
832
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
833
|
+
self._deps = self._get_dependencies()
|
828
834
|
assert isinstance(
|
829
835
|
dataset._session, Session
|
830
836
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -890,10 +896,8 @@ class MultiTaskElasticNetCV(BaseTransformer):
|
|
890
896
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
891
897
|
|
892
898
|
if isinstance(dataset, DataFrame):
|
893
|
-
self.
|
894
|
-
|
895
|
-
inference_method=inference_method,
|
896
|
-
)
|
899
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
900
|
+
self._deps = self._get_dependencies()
|
897
901
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
898
902
|
transform_kwargs = dict(
|
899
903
|
session=dataset._session,
|
@@ -957,17 +961,15 @@ class MultiTaskElasticNetCV(BaseTransformer):
|
|
957
961
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
958
962
|
|
959
963
|
if isinstance(dataset, DataFrame):
|
960
|
-
self.
|
961
|
-
|
962
|
-
inference_method="score",
|
963
|
-
)
|
964
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
965
|
+
self._deps = self._get_dependencies()
|
964
966
|
selected_cols = self._get_active_columns()
|
965
967
|
if len(selected_cols) > 0:
|
966
968
|
dataset = dataset.select(selected_cols)
|
967
969
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
968
970
|
transform_kwargs = dict(
|
969
971
|
session=dataset._session,
|
970
|
-
dependencies=
|
972
|
+
dependencies=self._deps,
|
971
973
|
score_sproc_imports=['sklearn'],
|
972
974
|
)
|
973
975
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1032,11 +1034,8 @@ class MultiTaskElasticNetCV(BaseTransformer):
|
|
1032
1034
|
|
1033
1035
|
if isinstance(dataset, DataFrame):
|
1034
1036
|
|
1035
|
-
self.
|
1036
|
-
|
1037
|
-
inference_method=inference_method,
|
1038
|
-
|
1039
|
-
)
|
1037
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1038
|
+
self._deps = self._get_dependencies()
|
1040
1039
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1041
1040
|
transform_kwargs = dict(
|
1042
1041
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class MultiTaskLasso(BaseTransformer):
|
70
64
|
r"""Multi-task Lasso model trained with L1/L2 mixed-norm as regularizer
|
71
65
|
For more details on this class, see [sklearn.linear_model.MultiTaskLasso]
|
@@ -305,20 +299,17 @@ class MultiTaskLasso(BaseTransformer):
|
|
305
299
|
self,
|
306
300
|
dataset: DataFrame,
|
307
301
|
inference_method: str,
|
308
|
-
) ->
|
309
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
310
|
-
return the available package that exists in the snowflake anaconda channel
|
302
|
+
) -> None:
|
303
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
311
304
|
|
312
305
|
Args:
|
313
306
|
dataset: snowpark dataframe
|
314
307
|
inference_method: the inference method such as predict, score...
|
315
|
-
|
308
|
+
|
316
309
|
Raises:
|
317
310
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
318
311
|
SnowflakeMLException: If the session is None, raise error
|
319
312
|
|
320
|
-
Returns:
|
321
|
-
A list of available package that exists in the snowflake anaconda channel
|
322
313
|
"""
|
323
314
|
if not self._is_fitted:
|
324
315
|
raise exceptions.SnowflakeMLException(
|
@@ -336,9 +327,7 @@ class MultiTaskLasso(BaseTransformer):
|
|
336
327
|
"Session must not specified for snowpark dataset."
|
337
328
|
),
|
338
329
|
)
|
339
|
-
|
340
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
341
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
330
|
+
|
342
331
|
|
343
332
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
344
333
|
@telemetry.send_api_usage_telemetry(
|
@@ -386,7 +375,8 @@ class MultiTaskLasso(BaseTransformer):
|
|
386
375
|
|
387
376
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
388
377
|
|
389
|
-
self.
|
378
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
379
|
+
self._deps = self._get_dependencies()
|
390
380
|
assert isinstance(
|
391
381
|
dataset._session, Session
|
392
382
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -469,10 +459,8 @@ class MultiTaskLasso(BaseTransformer):
|
|
469
459
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
470
460
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
471
461
|
|
472
|
-
self.
|
473
|
-
|
474
|
-
inference_method=inference_method,
|
475
|
-
)
|
462
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
463
|
+
self._deps = self._get_dependencies()
|
476
464
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
477
465
|
|
478
466
|
transform_kwargs = dict(
|
@@ -539,16 +527,40 @@ class MultiTaskLasso(BaseTransformer):
|
|
539
527
|
self._is_fitted = True
|
540
528
|
return output_result
|
541
529
|
|
530
|
+
|
531
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
532
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
533
|
+
""" Method not supported for this class.
|
542
534
|
|
543
|
-
|
544
|
-
|
545
|
-
|
535
|
+
|
536
|
+
Raises:
|
537
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
538
|
+
|
539
|
+
Args:
|
540
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
541
|
+
Snowpark or Pandas DataFrame.
|
542
|
+
output_cols_prefix: Prefix for the response columns
|
546
543
|
Returns:
|
547
544
|
Transformed dataset.
|
548
545
|
"""
|
549
|
-
self.
|
550
|
-
|
551
|
-
|
546
|
+
self._infer_input_output_cols(dataset)
|
547
|
+
super()._check_dataset_type(dataset)
|
548
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
549
|
+
estimator=self._sklearn_object,
|
550
|
+
dataset=dataset,
|
551
|
+
input_cols=self.input_cols,
|
552
|
+
label_cols=self.label_cols,
|
553
|
+
sample_weight_col=self.sample_weight_col,
|
554
|
+
autogenerated=self._autogenerated,
|
555
|
+
subproject=_SUBPROJECT,
|
556
|
+
)
|
557
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
558
|
+
drop_input_cols=self._drop_input_cols,
|
559
|
+
expected_output_cols_list=self.output_cols,
|
560
|
+
)
|
561
|
+
self._sklearn_object = fitted_estimator
|
562
|
+
self._is_fitted = True
|
563
|
+
return output_result
|
552
564
|
|
553
565
|
|
554
566
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -639,10 +651,8 @@ class MultiTaskLasso(BaseTransformer):
|
|
639
651
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
640
652
|
|
641
653
|
if isinstance(dataset, DataFrame):
|
642
|
-
self.
|
643
|
-
|
644
|
-
inference_method=inference_method,
|
645
|
-
)
|
654
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
655
|
+
self._deps = self._get_dependencies()
|
646
656
|
assert isinstance(
|
647
657
|
dataset._session, Session
|
648
658
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -707,10 +717,8 @@ class MultiTaskLasso(BaseTransformer):
|
|
707
717
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
708
718
|
|
709
719
|
if isinstance(dataset, DataFrame):
|
710
|
-
self.
|
711
|
-
|
712
|
-
inference_method=inference_method,
|
713
|
-
)
|
720
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
721
|
+
self._deps = self._get_dependencies()
|
714
722
|
assert isinstance(
|
715
723
|
dataset._session, Session
|
716
724
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -772,10 +780,8 @@ class MultiTaskLasso(BaseTransformer):
|
|
772
780
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
773
781
|
|
774
782
|
if isinstance(dataset, DataFrame):
|
775
|
-
self.
|
776
|
-
|
777
|
-
inference_method=inference_method,
|
778
|
-
)
|
783
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
784
|
+
self._deps = self._get_dependencies()
|
779
785
|
assert isinstance(
|
780
786
|
dataset._session, Session
|
781
787
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -841,10 +847,8 @@ class MultiTaskLasso(BaseTransformer):
|
|
841
847
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
842
848
|
|
843
849
|
if isinstance(dataset, DataFrame):
|
844
|
-
self.
|
845
|
-
|
846
|
-
inference_method=inference_method,
|
847
|
-
)
|
850
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
851
|
+
self._deps = self._get_dependencies()
|
848
852
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
849
853
|
transform_kwargs = dict(
|
850
854
|
session=dataset._session,
|
@@ -908,17 +912,15 @@ class MultiTaskLasso(BaseTransformer):
|
|
908
912
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
909
913
|
|
910
914
|
if isinstance(dataset, DataFrame):
|
911
|
-
self.
|
912
|
-
|
913
|
-
inference_method="score",
|
914
|
-
)
|
915
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
916
|
+
self._deps = self._get_dependencies()
|
915
917
|
selected_cols = self._get_active_columns()
|
916
918
|
if len(selected_cols) > 0:
|
917
919
|
dataset = dataset.select(selected_cols)
|
918
920
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
919
921
|
transform_kwargs = dict(
|
920
922
|
session=dataset._session,
|
921
|
-
dependencies=
|
923
|
+
dependencies=self._deps,
|
922
924
|
score_sproc_imports=['sklearn'],
|
923
925
|
)
|
924
926
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -983,11 +985,8 @@ class MultiTaskLasso(BaseTransformer):
|
|
983
985
|
|
984
986
|
if isinstance(dataset, DataFrame):
|
985
987
|
|
986
|
-
self.
|
987
|
-
|
988
|
-
inference_method=inference_method,
|
989
|
-
|
990
|
-
)
|
988
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
989
|
+
self._deps = self._get_dependencies()
|
991
990
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
992
991
|
transform_kwargs = dict(
|
993
992
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class MultiTaskLassoCV(BaseTransformer):
|
70
64
|
r"""Multi-task Lasso model trained with L1/L2 mixed-norm as regularizer
|
71
65
|
For more details on this class, see [sklearn.linear_model.MultiTaskLassoCV]
|
@@ -340,20 +334,17 @@ class MultiTaskLassoCV(BaseTransformer):
|
|
340
334
|
self,
|
341
335
|
dataset: DataFrame,
|
342
336
|
inference_method: str,
|
343
|
-
) ->
|
344
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
345
|
-
return the available package that exists in the snowflake anaconda channel
|
337
|
+
) -> None:
|
338
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
346
339
|
|
347
340
|
Args:
|
348
341
|
dataset: snowpark dataframe
|
349
342
|
inference_method: the inference method such as predict, score...
|
350
|
-
|
343
|
+
|
351
344
|
Raises:
|
352
345
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
353
346
|
SnowflakeMLException: If the session is None, raise error
|
354
347
|
|
355
|
-
Returns:
|
356
|
-
A list of available package that exists in the snowflake anaconda channel
|
357
348
|
"""
|
358
349
|
if not self._is_fitted:
|
359
350
|
raise exceptions.SnowflakeMLException(
|
@@ -371,9 +362,7 @@ class MultiTaskLassoCV(BaseTransformer):
|
|
371
362
|
"Session must not specified for snowpark dataset."
|
372
363
|
),
|
373
364
|
)
|
374
|
-
|
375
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
376
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
365
|
+
|
377
366
|
|
378
367
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
379
368
|
@telemetry.send_api_usage_telemetry(
|
@@ -421,7 +410,8 @@ class MultiTaskLassoCV(BaseTransformer):
|
|
421
410
|
|
422
411
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
423
412
|
|
424
|
-
self.
|
413
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
414
|
+
self._deps = self._get_dependencies()
|
425
415
|
assert isinstance(
|
426
416
|
dataset._session, Session
|
427
417
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -504,10 +494,8 @@ class MultiTaskLassoCV(BaseTransformer):
|
|
504
494
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
505
495
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
506
496
|
|
507
|
-
self.
|
508
|
-
|
509
|
-
inference_method=inference_method,
|
510
|
-
)
|
497
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
498
|
+
self._deps = self._get_dependencies()
|
511
499
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
512
500
|
|
513
501
|
transform_kwargs = dict(
|
@@ -574,16 +562,40 @@ class MultiTaskLassoCV(BaseTransformer):
|
|
574
562
|
self._is_fitted = True
|
575
563
|
return output_result
|
576
564
|
|
565
|
+
|
566
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
567
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
568
|
+
""" Method not supported for this class.
|
577
569
|
|
578
|
-
|
579
|
-
|
580
|
-
|
570
|
+
|
571
|
+
Raises:
|
572
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
573
|
+
|
574
|
+
Args:
|
575
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
576
|
+
Snowpark or Pandas DataFrame.
|
577
|
+
output_cols_prefix: Prefix for the response columns
|
581
578
|
Returns:
|
582
579
|
Transformed dataset.
|
583
580
|
"""
|
584
|
-
self.
|
585
|
-
|
586
|
-
|
581
|
+
self._infer_input_output_cols(dataset)
|
582
|
+
super()._check_dataset_type(dataset)
|
583
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
584
|
+
estimator=self._sklearn_object,
|
585
|
+
dataset=dataset,
|
586
|
+
input_cols=self.input_cols,
|
587
|
+
label_cols=self.label_cols,
|
588
|
+
sample_weight_col=self.sample_weight_col,
|
589
|
+
autogenerated=self._autogenerated,
|
590
|
+
subproject=_SUBPROJECT,
|
591
|
+
)
|
592
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
593
|
+
drop_input_cols=self._drop_input_cols,
|
594
|
+
expected_output_cols_list=self.output_cols,
|
595
|
+
)
|
596
|
+
self._sklearn_object = fitted_estimator
|
597
|
+
self._is_fitted = True
|
598
|
+
return output_result
|
587
599
|
|
588
600
|
|
589
601
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -674,10 +686,8 @@ class MultiTaskLassoCV(BaseTransformer):
|
|
674
686
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
675
687
|
|
676
688
|
if isinstance(dataset, DataFrame):
|
677
|
-
self.
|
678
|
-
|
679
|
-
inference_method=inference_method,
|
680
|
-
)
|
689
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
690
|
+
self._deps = self._get_dependencies()
|
681
691
|
assert isinstance(
|
682
692
|
dataset._session, Session
|
683
693
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -742,10 +752,8 @@ class MultiTaskLassoCV(BaseTransformer):
|
|
742
752
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
743
753
|
|
744
754
|
if isinstance(dataset, DataFrame):
|
745
|
-
self.
|
746
|
-
|
747
|
-
inference_method=inference_method,
|
748
|
-
)
|
755
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
756
|
+
self._deps = self._get_dependencies()
|
749
757
|
assert isinstance(
|
750
758
|
dataset._session, Session
|
751
759
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -807,10 +815,8 @@ class MultiTaskLassoCV(BaseTransformer):
|
|
807
815
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
808
816
|
|
809
817
|
if isinstance(dataset, DataFrame):
|
810
|
-
self.
|
811
|
-
|
812
|
-
inference_method=inference_method,
|
813
|
-
)
|
818
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
819
|
+
self._deps = self._get_dependencies()
|
814
820
|
assert isinstance(
|
815
821
|
dataset._session, Session
|
816
822
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -876,10 +882,8 @@ class MultiTaskLassoCV(BaseTransformer):
|
|
876
882
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
877
883
|
|
878
884
|
if isinstance(dataset, DataFrame):
|
879
|
-
self.
|
880
|
-
|
881
|
-
inference_method=inference_method,
|
882
|
-
)
|
885
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
886
|
+
self._deps = self._get_dependencies()
|
883
887
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
884
888
|
transform_kwargs = dict(
|
885
889
|
session=dataset._session,
|
@@ -943,17 +947,15 @@ class MultiTaskLassoCV(BaseTransformer):
|
|
943
947
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
944
948
|
|
945
949
|
if isinstance(dataset, DataFrame):
|
946
|
-
self.
|
947
|
-
|
948
|
-
inference_method="score",
|
949
|
-
)
|
950
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
951
|
+
self._deps = self._get_dependencies()
|
950
952
|
selected_cols = self._get_active_columns()
|
951
953
|
if len(selected_cols) > 0:
|
952
954
|
dataset = dataset.select(selected_cols)
|
953
955
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
954
956
|
transform_kwargs = dict(
|
955
957
|
session=dataset._session,
|
956
|
-
dependencies=
|
958
|
+
dependencies=self._deps,
|
957
959
|
score_sproc_imports=['sklearn'],
|
958
960
|
)
|
959
961
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1018,11 +1020,8 @@ class MultiTaskLassoCV(BaseTransformer):
|
|
1018
1020
|
|
1019
1021
|
if isinstance(dataset, DataFrame):
|
1020
1022
|
|
1021
|
-
self.
|
1022
|
-
|
1023
|
-
inference_method=inference_method,
|
1024
|
-
|
1025
|
-
)
|
1023
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1024
|
+
self._deps = self._get_dependencies()
|
1026
1025
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1027
1026
|
transform_kwargs = dict(
|
1028
1027
|
session = dataset._session,
|