snowflake-ml-python 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/file_utils.py +3 -3
- snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
- snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
- snowflake/ml/_internal/telemetry.py +11 -2
- snowflake/ml/_internal/utils/formatting.py +1 -1
- snowflake/ml/feature_store/feature_store.py +15 -106
- snowflake/ml/fileset/sfcfs.py +4 -3
- snowflake/ml/fileset/stage_fs.py +18 -0
- snowflake/ml/model/_api.py +9 -9
- snowflake/ml/model/_client/model/model_version_impl.py +20 -15
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +3 -9
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -5
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +7 -6
- snowflake/ml/model/_model_composer/model_composer.py +10 -8
- snowflake/ml/model/_model_composer/model_method/function_generator.py +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +2 -2
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +1 -1
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +7 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +2 -2
- snowflake/ml/model/_packager/model_handlers/llm.py +1 -1
- snowflake/ml/model/_packager/model_handlers/mlflow.py +1 -1
- snowflake/ml/model/_packager/model_handlers/pytorch.py +13 -10
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +214 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +6 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +15 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +8 -8
- snowflake/ml/model/_packager/model_handlers/torchscript.py +7 -7
- snowflake/ml/model/_packager/model_handlers/xgboost.py +8 -8
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_packager.py +8 -6
- snowflake/ml/model/custom_model.py +3 -1
- snowflake/ml/model/type_hints.py +13 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +61 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -43
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +4 -4
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +21 -17
- snowflake/ml/modeling/_internal/model_specifications.py +3 -1
- snowflake/ml/modeling/_internal/model_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +547 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +67 -114
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -9
- snowflake/ml/modeling/_internal/transformer_protocols.py +2 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +33 -61
- snowflake/ml/modeling/cluster/affinity_propagation.py +33 -61
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +33 -61
- snowflake/ml/modeling/cluster/birch.py +33 -61
- snowflake/ml/modeling/cluster/bisecting_k_means.py +33 -61
- snowflake/ml/modeling/cluster/dbscan.py +33 -61
- snowflake/ml/modeling/cluster/feature_agglomeration.py +33 -61
- snowflake/ml/modeling/cluster/k_means.py +33 -61
- snowflake/ml/modeling/cluster/mean_shift.py +33 -61
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +33 -61
- snowflake/ml/modeling/cluster/optics.py +33 -61
- snowflake/ml/modeling/cluster/spectral_biclustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_clustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_coclustering.py +33 -61
- snowflake/ml/modeling/compose/column_transformer.py +33 -61
- snowflake/ml/modeling/compose/transformed_target_regressor.py +33 -61
- snowflake/ml/modeling/covariance/elliptic_envelope.py +33 -61
- snowflake/ml/modeling/covariance/empirical_covariance.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +33 -61
- snowflake/ml/modeling/covariance/ledoit_wolf.py +33 -61
- snowflake/ml/modeling/covariance/min_cov_det.py +33 -61
- snowflake/ml/modeling/covariance/oas.py +33 -61
- snowflake/ml/modeling/covariance/shrunk_covariance.py +33 -61
- snowflake/ml/modeling/decomposition/dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/factor_analysis.py +33 -61
- snowflake/ml/modeling/decomposition/fast_ica.py +33 -61
- snowflake/ml/modeling/decomposition/incremental_pca.py +33 -61
- snowflake/ml/modeling/decomposition/kernel_pca.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/pca.py +33 -61
- snowflake/ml/modeling/decomposition/sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/truncated_svd.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/isolation_forest.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/stacking_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/voting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/voting_regressor.py +33 -61
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fdr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fpr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fwe.py +33 -61
- snowflake/ml/modeling/feature_selection/select_k_best.py +33 -61
- snowflake/ml/modeling/feature_selection/select_percentile.py +33 -61
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +33 -61
- snowflake/ml/modeling/feature_selection/variance_threshold.py +33 -61
- snowflake/ml/modeling/framework/base.py +55 -5
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +33 -61
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +33 -61
- snowflake/ml/modeling/impute/iterative_imputer.py +33 -61
- snowflake/ml/modeling/impute/knn_imputer.py +33 -61
- snowflake/ml/modeling/impute/missing_indicator.py +33 -61
- snowflake/ml/modeling/impute/simple_imputer.py +4 -15
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/nystroem.py +33 -61
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +33 -61
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +33 -61
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +36 -63
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +36 -63
- snowflake/ml/modeling/linear_model/ard_regression.py +33 -61
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/gamma_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/huber_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/lars.py +33 -61
- snowflake/ml/modeling/linear_model/lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +33 -61
- snowflake/ml/modeling/linear_model/linear_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/perceptron.py +33 -61
- snowflake/ml/modeling/linear_model/poisson_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ransac_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ridge.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_cv.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +33 -61
- snowflake/ml/modeling/manifold/isomap.py +33 -61
- snowflake/ml/modeling/manifold/mds.py +33 -61
- snowflake/ml/modeling/manifold/spectral_embedding.py +33 -61
- snowflake/ml/modeling/manifold/tsne.py +33 -61
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +33 -61
- snowflake/ml/modeling/mixture/gaussian_mixture.py +33 -61
- snowflake/ml/modeling/model_selection/grid_search_cv.py +39 -57
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +26 -57
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/output_code_classifier.py +33 -61
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/complement_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neighbors/kernel_density.py +33 -61
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_centroid.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +33 -61
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_classifier.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_regressor.py +33 -61
- snowflake/ml/modeling/preprocessing/polynomial_features.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_propagation.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_spreading.py +33 -61
- snowflake/ml/modeling/svm/linear_svc.py +33 -61
- snowflake/ml/modeling/svm/linear_svr.py +33 -61
- snowflake/ml/modeling/svm/nu_svc.py +33 -61
- snowflake/ml/modeling/svm/nu_svr.py +33 -61
- snowflake/ml/modeling/svm/svc.py +33 -61
- snowflake/ml/modeling/svm/svr.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_regressor.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +33 -61
- snowflake/ml/registry/_manager/model_manager.py +6 -2
- snowflake/ml/registry/model_registry.py +100 -27
- snowflake/ml/registry/registry.py +6 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/METADATA +43 -7
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/RECORD +211 -206
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/top_level.txt +0 -0
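
The release itself is the one-line bump in `snowflake/ml/version.py` near the end of the list. A quick sanity check after upgrading (a sketch; the module-level `VERSION` constant is assumed from the conventional layout of single-line version modules, not confirmed by this diff):

```python
# Hedged sketch: assumes snowflake/ml/version.py exposes a module-level
# VERSION string, as one-line version modules conventionally do.
from snowflake.ml import version

print(version.VERSION)  # expected: "1.4.0" after the upgrade, "1.3.0" before
```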
snowflake/ml/modeling/linear_model/logistic_regression_cv.py

```diff
@@ -415,18 +415,24 @@ class LogisticRegressionCV(BaseTransformer):
         self._get_model_signatures(dataset)
         return self

-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
+
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -500,7 +506,7 @@ class LogisticRegressionCV(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_type_inferred,
             )

@@ -560,16 +566,16 @@ class LogisticRegressionCV(BaseTransformer):
             # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
             # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
             # each row containing a list of values.
-            expected_dtype = "
+            expected_dtype = "array"

         # If we were unable to assign a type to this transform in the factory, infer the type here.
         if expected_dtype == "":
-            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
             if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                expected_dtype = "
-            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
+                expected_dtype = "array"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
             elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                expected_dtype = "
+                expected_dtype = "array"
             else:
                 output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                 # We can only infer the output types from the input types if the following two statemetns are true:
@@ -587,7 +593,7 @@ class LogisticRegressionCV(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_dtype,
             )

@@ -638,7 +644,7 @@ class LogisticRegressionCV(BaseTransformer):
             subproject=_SUBPROJECT,
         )
         output_result, fitted_estimator = model_trainer.train_fit_predict(
-
+            drop_input_cols=self._drop_input_cols,
             expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
         )
         self._sklearn_object = fitted_estimator
@@ -656,44 +662,6 @@ class LogisticRegressionCV(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object.embedding_

-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-        else:
-            output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
@@ -735,7 +703,7 @@ class LogisticRegressionCV(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -802,7 +770,7 @@ class LogisticRegressionCV(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -865,7 +833,7 @@ class LogisticRegressionCV(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -930,7 +898,7 @@ class LogisticRegressionCV(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -986,13 +954,17 @@ class LogisticRegressionCV(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self.
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -1066,9 +1038,9 @@ class LogisticRegressionCV(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
-                expected_output_cols_type
-                n_neighbors =
+                drop_input_cols = self._drop_input_cols,
+                expected_output_cols_type="array",
+                n_neighbors = n_neighbors,
                 return_distance = return_distance
             )
         elif isinstance(dataset, pd.DataFrame):
```
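
Three recurring edits make up almost all of this file's diff: the client-side helpers `_get_pass_through_columns` and `_get_output_column_names` are deleted, every `transform_kwargs` gains a `drop_input_cols` entry, and `score()` now resolves its dependencies through the extended `_batch_inference_validate_snowpark` before building the stored-procedure call. A condensed sketch of the new `score()` shape (names taken from the hunks above; the validation body is a placeholder, not the real implementation):

```python
from typing import Any, Dict, List


class ScoreCallSketch:
    """Condensed 1.4.0 score() pattern from the hunks above (not the real class)."""

    def __init__(self) -> None:
        self._deps: List[str] = []      # packages matched to the Snowflake Anaconda channel
        self._drop_input_cols = False   # now forwarded instead of pass-through columns

    def _batch_inference_validate_snowpark(self, dataset: Any, inference_method: str) -> List[str]:
        # Placeholder: the real method (diffed above) checks the fitted state and
        # session, then returns the packages available in the Anaconda channel.
        return ["scikit-learn"]

    def score_kwargs(self, dataset: Any) -> Dict[str, Any]:
        # New in 1.4.0: dependencies are resolved up front...
        self._deps = self._batch_inference_validate_snowpark(
            dataset=dataset,
            inference_method="score",
        )
        # ...and fed straight into the scoring stored-procedure kwargs.
        return dict(
            dependencies=["snowflake-snowpark-python"] + self._deps,
            score_sproc_imports=["sklearn"],
        )
```

The same shape appears verbatim in the generated estimators that follow.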
snowflake/ml/modeling/linear_model/multi_task_elastic_net.py

```diff
@@ -313,18 +313,24 @@ class MultiTaskElasticNet(BaseTransformer):
         self._get_model_signatures(dataset)
         return self

-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
+
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -398,7 +404,7 @@ class MultiTaskElasticNet(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_type_inferred,
             )

@@ -458,16 +464,16 @@ class MultiTaskElasticNet(BaseTransformer):
             # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
             # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
             # each row containing a list of values.
-            expected_dtype = "
+            expected_dtype = "array"

         # If we were unable to assign a type to this transform in the factory, infer the type here.
         if expected_dtype == "":
-            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
             if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                expected_dtype = "
-            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
+                expected_dtype = "array"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
             elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                expected_dtype = "
+                expected_dtype = "array"
             else:
                 output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                 # We can only infer the output types from the input types if the following two statemetns are true:
@@ -485,7 +491,7 @@ class MultiTaskElasticNet(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_dtype,
             )

@@ -536,7 +542,7 @@ class MultiTaskElasticNet(BaseTransformer):
             subproject=_SUBPROJECT,
         )
         output_result, fitted_estimator = model_trainer.train_fit_predict(
-
+            drop_input_cols=self._drop_input_cols,
             expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
         )
         self._sklearn_object = fitted_estimator
@@ -554,44 +560,6 @@ class MultiTaskElasticNet(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object.embedding_

-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-        else:
-            output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
@@ -631,7 +599,7 @@ class MultiTaskElasticNet(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -696,7 +664,7 @@ class MultiTaskElasticNet(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -757,7 +725,7 @@ class MultiTaskElasticNet(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -822,7 +790,7 @@ class MultiTaskElasticNet(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -878,13 +846,17 @@ class MultiTaskElasticNet(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self.
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -958,9 +930,9 @@ class MultiTaskElasticNet(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
-                expected_output_cols_type
-                n_neighbors =
+                drop_input_cols = self._drop_input_cols,
+                expected_output_cols_type="array",
+                n_neighbors = n_neighbors,
                 return_distance = return_distance
             )
         elif isinstance(dataset, pd.DataFrame):
```
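
MultiTaskElasticNet gets the identical rewrite, which makes the intent of the removal easy to read side by side: in 1.3.0 the client computed the pass-through columns itself, while in 1.4.0 only the `drop_input_cols` flag travels in `transform_kwargs`, presumably leaving that derivation to the shared inference handlers. A sketch (the first function's body is copied from the deleted hunk; the second mirrors the added kwargs; the standalone function signatures are illustrative, not the package's API):

```python
from typing import List, Sequence


# 1.3.0 behavior (removed above): pass-through columns computed client-side
# from the dataset's columns and the estimator's output_cols.
def pass_through_columns(
    dataset_columns: Sequence[str],
    output_cols: Sequence[str],
    drop_input_cols: bool,
) -> List[str]:
    if drop_input_cols:
        return []
    return list(set(dataset_columns) - set(output_cols))


# 1.4.0 behavior (added above): ship the flag, not the derived column list.
def predict_transform_kwargs(session: object, deps: List[str], drop_input_cols: bool) -> dict:
    return dict(
        session=session,
        dependencies=deps,
        drop_input_cols=drop_input_cols,
        expected_output_cols_type="float",
    )
```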
snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py

```diff
@@ -354,18 +354,24 @@ class MultiTaskElasticNetCV(BaseTransformer):
         self._get_model_signatures(dataset)
         return self

-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
+
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -439,7 +445,7 @@ class MultiTaskElasticNetCV(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_type_inferred,
             )

@@ -499,16 +505,16 @@ class MultiTaskElasticNetCV(BaseTransformer):
             # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
             # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
             # each row containing a list of values.
-            expected_dtype = "
+            expected_dtype = "array"

         # If we were unable to assign a type to this transform in the factory, infer the type here.
         if expected_dtype == "":
-            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
             if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                expected_dtype = "
-            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
+                expected_dtype = "array"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
             elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                expected_dtype = "
+                expected_dtype = "array"
             else:
                 output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                 # We can only infer the output types from the input types if the following two statemetns are true:
@@ -526,7 +532,7 @@ class MultiTaskElasticNetCV(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_dtype,
             )

@@ -577,7 +583,7 @@ class MultiTaskElasticNetCV(BaseTransformer):
             subproject=_SUBPROJECT,
         )
         output_result, fitted_estimator = model_trainer.train_fit_predict(
-
+            drop_input_cols=self._drop_input_cols,
             expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
         )
         self._sklearn_object = fitted_estimator
@@ -595,44 +601,6 @@ class MultiTaskElasticNetCV(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object.embedding_

-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-        else:
-            output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
@@ -672,7 +640,7 @@ class MultiTaskElasticNetCV(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -737,7 +705,7 @@ class MultiTaskElasticNetCV(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -798,7 +766,7 @@ class MultiTaskElasticNetCV(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -863,7 +831,7 @@ class MultiTaskElasticNetCV(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -919,13 +887,17 @@ class MultiTaskElasticNetCV(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self.
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -999,9 +971,9 @@ class MultiTaskElasticNetCV(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
-                expected_output_cols_type
-                n_neighbors =
+                drop_input_cols = self._drop_input_cols,
+                expected_output_cols_type="array",
+                n_neighbors = n_neighbors,
                 return_distance = return_distance
             )
         elif isinstance(dataset, pd.DataFrame):
```
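
None of these internal changes alter the public estimator surface: fit and predict against a Snowpark DataFrame look the same in 1.4.0 as in 1.3.0. A minimal usage sketch (connection parameters, table, and column names are placeholders, not values taken from this diff):

```python
from snowflake.snowpark import Session
from snowflake.ml.modeling.linear_model import LogisticRegressionCV

# Placeholder credentials; fill in with a real account.
connection_parameters = {"account": "<account>", "user": "<user>", "password": "<password>"}
session = Session.builder.configs(connection_parameters).create()

df = session.table("MY_TRAINING_TABLE")  # hypothetical table

clf = LogisticRegressionCV(
    input_cols=["FEATURE_1", "FEATURE_2"],  # hypothetical feature columns
    label_cols=["LABEL"],                   # hypothetical label column
    output_cols=["PREDICTION"],
)
clf.fit(df)                    # training runs in a stored procedure
predictions = clf.predict(df)  # batch inference path refactored in this release
```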