snowflake-ml-python 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/file_utils.py +3 -3
- snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
- snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
- snowflake/ml/_internal/telemetry.py +11 -2
- snowflake/ml/_internal/utils/formatting.py +1 -1
- snowflake/ml/feature_store/feature_store.py +15 -106
- snowflake/ml/fileset/sfcfs.py +4 -3
- snowflake/ml/fileset/stage_fs.py +18 -0
- snowflake/ml/model/_api.py +9 -9
- snowflake/ml/model/_client/model/model_version_impl.py +20 -15
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +3 -9
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -5
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +7 -6
- snowflake/ml/model/_model_composer/model_composer.py +10 -8
- snowflake/ml/model/_model_composer/model_method/function_generator.py +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +2 -2
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +1 -1
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +7 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +2 -2
- snowflake/ml/model/_packager/model_handlers/llm.py +1 -1
- snowflake/ml/model/_packager/model_handlers/mlflow.py +1 -1
- snowflake/ml/model/_packager/model_handlers/pytorch.py +13 -10
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +214 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +6 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +15 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +8 -8
- snowflake/ml/model/_packager/model_handlers/torchscript.py +7 -7
- snowflake/ml/model/_packager/model_handlers/xgboost.py +8 -8
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_packager.py +8 -6
- snowflake/ml/model/custom_model.py +3 -1
- snowflake/ml/model/type_hints.py +13 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +61 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -43
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +4 -4
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +21 -17
- snowflake/ml/modeling/_internal/model_specifications.py +3 -1
- snowflake/ml/modeling/_internal/model_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +547 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +67 -114
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -9
- snowflake/ml/modeling/_internal/transformer_protocols.py +2 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +33 -61
- snowflake/ml/modeling/cluster/affinity_propagation.py +33 -61
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +33 -61
- snowflake/ml/modeling/cluster/birch.py +33 -61
- snowflake/ml/modeling/cluster/bisecting_k_means.py +33 -61
- snowflake/ml/modeling/cluster/dbscan.py +33 -61
- snowflake/ml/modeling/cluster/feature_agglomeration.py +33 -61
- snowflake/ml/modeling/cluster/k_means.py +33 -61
- snowflake/ml/modeling/cluster/mean_shift.py +33 -61
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +33 -61
- snowflake/ml/modeling/cluster/optics.py +33 -61
- snowflake/ml/modeling/cluster/spectral_biclustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_clustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_coclustering.py +33 -61
- snowflake/ml/modeling/compose/column_transformer.py +33 -61
- snowflake/ml/modeling/compose/transformed_target_regressor.py +33 -61
- snowflake/ml/modeling/covariance/elliptic_envelope.py +33 -61
- snowflake/ml/modeling/covariance/empirical_covariance.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +33 -61
- snowflake/ml/modeling/covariance/ledoit_wolf.py +33 -61
- snowflake/ml/modeling/covariance/min_cov_det.py +33 -61
- snowflake/ml/modeling/covariance/oas.py +33 -61
- snowflake/ml/modeling/covariance/shrunk_covariance.py +33 -61
- snowflake/ml/modeling/decomposition/dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/factor_analysis.py +33 -61
- snowflake/ml/modeling/decomposition/fast_ica.py +33 -61
- snowflake/ml/modeling/decomposition/incremental_pca.py +33 -61
- snowflake/ml/modeling/decomposition/kernel_pca.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/pca.py +33 -61
- snowflake/ml/modeling/decomposition/sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/truncated_svd.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/isolation_forest.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/stacking_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/voting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/voting_regressor.py +33 -61
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fdr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fpr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fwe.py +33 -61
- snowflake/ml/modeling/feature_selection/select_k_best.py +33 -61
- snowflake/ml/modeling/feature_selection/select_percentile.py +33 -61
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +33 -61
- snowflake/ml/modeling/feature_selection/variance_threshold.py +33 -61
- snowflake/ml/modeling/framework/base.py +55 -5
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +33 -61
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +33 -61
- snowflake/ml/modeling/impute/iterative_imputer.py +33 -61
- snowflake/ml/modeling/impute/knn_imputer.py +33 -61
- snowflake/ml/modeling/impute/missing_indicator.py +33 -61
- snowflake/ml/modeling/impute/simple_imputer.py +4 -15
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/nystroem.py +33 -61
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +33 -61
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +33 -61
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +36 -63
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +36 -63
- snowflake/ml/modeling/linear_model/ard_regression.py +33 -61
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/gamma_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/huber_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/lars.py +33 -61
- snowflake/ml/modeling/linear_model/lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +33 -61
- snowflake/ml/modeling/linear_model/linear_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/perceptron.py +33 -61
- snowflake/ml/modeling/linear_model/poisson_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ransac_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ridge.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_cv.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +33 -61
- snowflake/ml/modeling/manifold/isomap.py +33 -61
- snowflake/ml/modeling/manifold/mds.py +33 -61
- snowflake/ml/modeling/manifold/spectral_embedding.py +33 -61
- snowflake/ml/modeling/manifold/tsne.py +33 -61
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +33 -61
- snowflake/ml/modeling/mixture/gaussian_mixture.py +33 -61
- snowflake/ml/modeling/model_selection/grid_search_cv.py +39 -57
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +26 -57
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/output_code_classifier.py +33 -61
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/complement_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neighbors/kernel_density.py +33 -61
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_centroid.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +33 -61
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_classifier.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_regressor.py +33 -61
- snowflake/ml/modeling/preprocessing/polynomial_features.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_propagation.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_spreading.py +33 -61
- snowflake/ml/modeling/svm/linear_svc.py +33 -61
- snowflake/ml/modeling/svm/linear_svr.py +33 -61
- snowflake/ml/modeling/svm/nu_svc.py +33 -61
- snowflake/ml/modeling/svm/nu_svr.py +33 -61
- snowflake/ml/modeling/svm/svc.py +33 -61
- snowflake/ml/modeling/svm/svr.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_regressor.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +33 -61
- snowflake/ml/registry/_manager/model_manager.py +6 -2
- snowflake/ml/registry/model_registry.py +100 -27
- snowflake/ml/registry/registry.py +6 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/METADATA +43 -7
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/RECORD +211 -206
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/top_level.txt +0 -0
@@ -305,18 +305,24 @@ class MultiTaskLasso(BaseTransformer):
|
|
305
305
|
self._get_model_signatures(dataset)
|
306
306
|
return self
|
307
307
|
|
308
|
-
def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
|
309
|
-
if self._drop_input_cols:
|
310
|
-
return []
|
311
|
-
else:
|
312
|
-
return list(set(dataset.columns) - set(self.output_cols))
|
313
|
-
|
314
308
|
def _batch_inference_validate_snowpark(
|
315
309
|
self,
|
316
310
|
dataset: DataFrame,
|
317
311
|
inference_method: str,
|
318
312
|
) -> List[str]:
|
319
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
313
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe and
|
314
|
+
return the available package that exists in the snowflake anaconda channel
|
315
|
+
|
316
|
+
Args:
|
317
|
+
dataset: snowpark dataframe
|
318
|
+
inference_method: the inference method such as predict, score...
|
319
|
+
|
320
|
+
Raises:
|
321
|
+
SnowflakeMLException: If the estimator is not fitted, raise error
|
322
|
+
SnowflakeMLException: If the session is None, raise error
|
323
|
+
|
324
|
+
Returns:
|
325
|
+
A list of available package that exists in the snowflake anaconda channel
|
320
326
|
"""
|
321
327
|
if not self._is_fitted:
|
322
328
|
raise exceptions.SnowflakeMLException(
|
@@ -390,7 +396,7 @@ class MultiTaskLasso(BaseTransformer):
|
|
390
396
|
transform_kwargs = dict(
|
391
397
|
session = dataset._session,
|
392
398
|
dependencies = self._deps,
|
393
|
-
|
399
|
+
drop_input_cols = self._drop_input_cols,
|
394
400
|
expected_output_cols_type = expected_type_inferred,
|
395
401
|
)
|
396
402
|
|
@@ -450,16 +456,16 @@ class MultiTaskLasso(BaseTransformer):
|
|
450
456
|
# from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
|
451
457
|
# based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
|
452
458
|
# each row containing a list of values.
|
453
|
-
expected_dtype = "
|
459
|
+
expected_dtype = "array"
|
454
460
|
|
455
461
|
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
456
462
|
if expected_dtype == "":
|
457
|
-
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
|
463
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
|
458
464
|
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
459
|
-
expected_dtype = "
|
460
|
-
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
|
465
|
+
expected_dtype = "array"
|
466
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
|
461
467
|
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
462
|
-
expected_dtype = "
|
468
|
+
expected_dtype = "array"
|
463
469
|
else:
|
464
470
|
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
465
471
|
# We can only infer the output types from the input types if the following two statemetns are true:
|
@@ -477,7 +483,7 @@ class MultiTaskLasso(BaseTransformer):
|
|
477
483
|
transform_kwargs = dict(
|
478
484
|
session = dataset._session,
|
479
485
|
dependencies = self._deps,
|
480
|
-
|
486
|
+
drop_input_cols = self._drop_input_cols,
|
481
487
|
expected_output_cols_type = expected_dtype,
|
482
488
|
)
|
483
489
|
|
@@ -528,7 +534,7 @@ class MultiTaskLasso(BaseTransformer):
|
|
528
534
|
subproject=_SUBPROJECT,
|
529
535
|
)
|
530
536
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
531
|
-
|
537
|
+
drop_input_cols=self._drop_input_cols,
|
532
538
|
expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
|
533
539
|
)
|
534
540
|
self._sklearn_object = fitted_estimator
|
@@ -546,44 +552,6 @@ class MultiTaskLasso(BaseTransformer):
|
|
546
552
|
assert self._sklearn_object is not None
|
547
553
|
return self._sklearn_object.embedding_
|
548
554
|
|
549
|
-
|
550
|
-
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
551
|
-
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
552
|
-
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
553
|
-
"""
|
554
|
-
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
555
|
-
if output_cols:
|
556
|
-
output_cols = [
|
557
|
-
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
558
|
-
for c in output_cols
|
559
|
-
]
|
560
|
-
elif getattr(self._sklearn_object, "classes_", None) is None:
|
561
|
-
output_cols = [output_cols_prefix]
|
562
|
-
elif self._sklearn_object is not None:
|
563
|
-
classes = self._sklearn_object.classes_
|
564
|
-
if isinstance(classes, numpy.ndarray):
|
565
|
-
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
566
|
-
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
567
|
-
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
568
|
-
output_cols = []
|
569
|
-
for i, cl in enumerate(classes):
|
570
|
-
# For binary classification, there is only one output column for each class
|
571
|
-
# ndarray as the two classes are complementary.
|
572
|
-
if len(cl) == 2:
|
573
|
-
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
574
|
-
else:
|
575
|
-
output_cols.extend([
|
576
|
-
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
577
|
-
])
|
578
|
-
else:
|
579
|
-
output_cols = []
|
580
|
-
|
581
|
-
# Make sure column names are valid snowflake identifiers.
|
582
|
-
assert output_cols is not None # Make MyPy happy
|
583
|
-
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
584
|
-
|
585
|
-
return rv
|
586
|
-
|
587
555
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
588
556
|
@telemetry.send_api_usage_telemetry(
|
589
557
|
project=_PROJECT,
|
@@ -623,7 +591,7 @@ class MultiTaskLasso(BaseTransformer):
|
|
623
591
|
transform_kwargs = dict(
|
624
592
|
session=dataset._session,
|
625
593
|
dependencies=self._deps,
|
626
|
-
|
594
|
+
drop_input_cols = self._drop_input_cols,
|
627
595
|
expected_output_cols_type="float",
|
628
596
|
)
|
629
597
|
|
@@ -688,7 +656,7 @@ class MultiTaskLasso(BaseTransformer):
|
|
688
656
|
transform_kwargs = dict(
|
689
657
|
session=dataset._session,
|
690
658
|
dependencies=self._deps,
|
691
|
-
|
659
|
+
drop_input_cols = self._drop_input_cols,
|
692
660
|
expected_output_cols_type="float",
|
693
661
|
)
|
694
662
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -749,7 +717,7 @@ class MultiTaskLasso(BaseTransformer):
|
|
749
717
|
transform_kwargs = dict(
|
750
718
|
session=dataset._session,
|
751
719
|
dependencies=self._deps,
|
752
|
-
|
720
|
+
drop_input_cols = self._drop_input_cols,
|
753
721
|
expected_output_cols_type="float",
|
754
722
|
)
|
755
723
|
|
@@ -814,7 +782,7 @@ class MultiTaskLasso(BaseTransformer):
|
|
814
782
|
transform_kwargs = dict(
|
815
783
|
session=dataset._session,
|
816
784
|
dependencies=self._deps,
|
817
|
-
|
785
|
+
drop_input_cols = self._drop_input_cols,
|
818
786
|
expected_output_cols_type="float",
|
819
787
|
)
|
820
788
|
|
@@ -870,13 +838,17 @@ class MultiTaskLasso(BaseTransformer):
|
|
870
838
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
871
839
|
|
872
840
|
if isinstance(dataset, DataFrame):
|
841
|
+
self._deps = self._batch_inference_validate_snowpark(
|
842
|
+
dataset=dataset,
|
843
|
+
inference_method="score",
|
844
|
+
)
|
873
845
|
selected_cols = self._get_active_columns()
|
874
846
|
if len(selected_cols) > 0:
|
875
847
|
dataset = dataset.select(selected_cols)
|
876
848
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
877
849
|
transform_kwargs = dict(
|
878
850
|
session=dataset._session,
|
879
|
-
dependencies=["snowflake-snowpark-python"] + self.
|
851
|
+
dependencies=["snowflake-snowpark-python"] + self._deps,
|
880
852
|
score_sproc_imports=['sklearn'],
|
881
853
|
)
|
882
854
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -950,9 +922,9 @@ class MultiTaskLasso(BaseTransformer):
|
|
950
922
|
transform_kwargs = dict(
|
951
923
|
session = dataset._session,
|
952
924
|
dependencies = self._deps,
|
953
|
-
|
954
|
-
expected_output_cols_type
|
955
|
-
n_neighbors =
|
925
|
+
drop_input_cols = self._drop_input_cols,
|
926
|
+
expected_output_cols_type="array",
|
927
|
+
n_neighbors = n_neighbors,
|
956
928
|
return_distance = return_distance
|
957
929
|
)
|
958
930
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -340,18 +340,24 @@ class MultiTaskLassoCV(BaseTransformer):
|
|
340
340
|
self._get_model_signatures(dataset)
|
341
341
|
return self
|
342
342
|
|
343
|
-
def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
|
344
|
-
if self._drop_input_cols:
|
345
|
-
return []
|
346
|
-
else:
|
347
|
-
return list(set(dataset.columns) - set(self.output_cols))
|
348
|
-
|
349
343
|
def _batch_inference_validate_snowpark(
|
350
344
|
self,
|
351
345
|
dataset: DataFrame,
|
352
346
|
inference_method: str,
|
353
347
|
) -> List[str]:
|
354
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
348
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe and
|
349
|
+
return the available package that exists in the snowflake anaconda channel
|
350
|
+
|
351
|
+
Args:
|
352
|
+
dataset: snowpark dataframe
|
353
|
+
inference_method: the inference method such as predict, score...
|
354
|
+
|
355
|
+
Raises:
|
356
|
+
SnowflakeMLException: If the estimator is not fitted, raise error
|
357
|
+
SnowflakeMLException: If the session is None, raise error
|
358
|
+
|
359
|
+
Returns:
|
360
|
+
A list of available package that exists in the snowflake anaconda channel
|
355
361
|
"""
|
356
362
|
if not self._is_fitted:
|
357
363
|
raise exceptions.SnowflakeMLException(
|
@@ -425,7 +431,7 @@ class MultiTaskLassoCV(BaseTransformer):
|
|
425
431
|
transform_kwargs = dict(
|
426
432
|
session = dataset._session,
|
427
433
|
dependencies = self._deps,
|
428
|
-
|
434
|
+
drop_input_cols = self._drop_input_cols,
|
429
435
|
expected_output_cols_type = expected_type_inferred,
|
430
436
|
)
|
431
437
|
|
@@ -485,16 +491,16 @@ class MultiTaskLassoCV(BaseTransformer):
|
|
485
491
|
# from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
|
486
492
|
# based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
|
487
493
|
# each row containing a list of values.
|
488
|
-
expected_dtype = "
|
494
|
+
expected_dtype = "array"
|
489
495
|
|
490
496
|
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
491
497
|
if expected_dtype == "":
|
492
|
-
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
|
498
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
|
493
499
|
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
494
|
-
expected_dtype = "
|
495
|
-
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
|
500
|
+
expected_dtype = "array"
|
501
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
|
496
502
|
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
497
|
-
expected_dtype = "
|
503
|
+
expected_dtype = "array"
|
498
504
|
else:
|
499
505
|
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
500
506
|
# We can only infer the output types from the input types if the following two statemetns are true:
|
@@ -512,7 +518,7 @@ class MultiTaskLassoCV(BaseTransformer):
|
|
512
518
|
transform_kwargs = dict(
|
513
519
|
session = dataset._session,
|
514
520
|
dependencies = self._deps,
|
515
|
-
|
521
|
+
drop_input_cols = self._drop_input_cols,
|
516
522
|
expected_output_cols_type = expected_dtype,
|
517
523
|
)
|
518
524
|
|
@@ -563,7 +569,7 @@ class MultiTaskLassoCV(BaseTransformer):
|
|
563
569
|
subproject=_SUBPROJECT,
|
564
570
|
)
|
565
571
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
566
|
-
|
572
|
+
drop_input_cols=self._drop_input_cols,
|
567
573
|
expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
|
568
574
|
)
|
569
575
|
self._sklearn_object = fitted_estimator
|
@@ -581,44 +587,6 @@ class MultiTaskLassoCV(BaseTransformer):
|
|
581
587
|
assert self._sklearn_object is not None
|
582
588
|
return self._sklearn_object.embedding_
|
583
589
|
|
584
|
-
|
585
|
-
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
586
|
-
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
587
|
-
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
588
|
-
"""
|
589
|
-
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
590
|
-
if output_cols:
|
591
|
-
output_cols = [
|
592
|
-
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
593
|
-
for c in output_cols
|
594
|
-
]
|
595
|
-
elif getattr(self._sklearn_object, "classes_", None) is None:
|
596
|
-
output_cols = [output_cols_prefix]
|
597
|
-
elif self._sklearn_object is not None:
|
598
|
-
classes = self._sklearn_object.classes_
|
599
|
-
if isinstance(classes, numpy.ndarray):
|
600
|
-
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
601
|
-
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
602
|
-
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
603
|
-
output_cols = []
|
604
|
-
for i, cl in enumerate(classes):
|
605
|
-
# For binary classification, there is only one output column for each class
|
606
|
-
# ndarray as the two classes are complementary.
|
607
|
-
if len(cl) == 2:
|
608
|
-
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
609
|
-
else:
|
610
|
-
output_cols.extend([
|
611
|
-
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
612
|
-
])
|
613
|
-
else:
|
614
|
-
output_cols = []
|
615
|
-
|
616
|
-
# Make sure column names are valid snowflake identifiers.
|
617
|
-
assert output_cols is not None # Make MyPy happy
|
618
|
-
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
619
|
-
|
620
|
-
return rv
|
621
|
-
|
622
590
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
623
591
|
@telemetry.send_api_usage_telemetry(
|
624
592
|
project=_PROJECT,
|
@@ -658,7 +626,7 @@ class MultiTaskLassoCV(BaseTransformer):
|
|
658
626
|
transform_kwargs = dict(
|
659
627
|
session=dataset._session,
|
660
628
|
dependencies=self._deps,
|
661
|
-
|
629
|
+
drop_input_cols = self._drop_input_cols,
|
662
630
|
expected_output_cols_type="float",
|
663
631
|
)
|
664
632
|
|
@@ -723,7 +691,7 @@ class MultiTaskLassoCV(BaseTransformer):
|
|
723
691
|
transform_kwargs = dict(
|
724
692
|
session=dataset._session,
|
725
693
|
dependencies=self._deps,
|
726
|
-
|
694
|
+
drop_input_cols = self._drop_input_cols,
|
727
695
|
expected_output_cols_type="float",
|
728
696
|
)
|
729
697
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -784,7 +752,7 @@ class MultiTaskLassoCV(BaseTransformer):
|
|
784
752
|
transform_kwargs = dict(
|
785
753
|
session=dataset._session,
|
786
754
|
dependencies=self._deps,
|
787
|
-
|
755
|
+
drop_input_cols = self._drop_input_cols,
|
788
756
|
expected_output_cols_type="float",
|
789
757
|
)
|
790
758
|
|
@@ -849,7 +817,7 @@ class MultiTaskLassoCV(BaseTransformer):
|
|
849
817
|
transform_kwargs = dict(
|
850
818
|
session=dataset._session,
|
851
819
|
dependencies=self._deps,
|
852
|
-
|
820
|
+
drop_input_cols = self._drop_input_cols,
|
853
821
|
expected_output_cols_type="float",
|
854
822
|
)
|
855
823
|
|
@@ -905,13 +873,17 @@ class MultiTaskLassoCV(BaseTransformer):
|
|
905
873
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
906
874
|
|
907
875
|
if isinstance(dataset, DataFrame):
|
876
|
+
self._deps = self._batch_inference_validate_snowpark(
|
877
|
+
dataset=dataset,
|
878
|
+
inference_method="score",
|
879
|
+
)
|
908
880
|
selected_cols = self._get_active_columns()
|
909
881
|
if len(selected_cols) > 0:
|
910
882
|
dataset = dataset.select(selected_cols)
|
911
883
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
912
884
|
transform_kwargs = dict(
|
913
885
|
session=dataset._session,
|
914
|
-
dependencies=["snowflake-snowpark-python"] + self.
|
886
|
+
dependencies=["snowflake-snowpark-python"] + self._deps,
|
915
887
|
score_sproc_imports=['sklearn'],
|
916
888
|
)
|
917
889
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -985,9 +957,9 @@ class MultiTaskLassoCV(BaseTransformer):
|
|
985
957
|
transform_kwargs = dict(
|
986
958
|
session = dataset._session,
|
987
959
|
dependencies = self._deps,
|
988
|
-
|
989
|
-
expected_output_cols_type
|
990
|
-
n_neighbors =
|
960
|
+
drop_input_cols = self._drop_input_cols,
|
961
|
+
expected_output_cols_type="array",
|
962
|
+
n_neighbors = n_neighbors,
|
991
963
|
return_distance = return_distance
|
992
964
|
)
|
993
965
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -288,18 +288,24 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
288
288
|
self._get_model_signatures(dataset)
|
289
289
|
return self
|
290
290
|
|
291
|
-
def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
|
292
|
-
if self._drop_input_cols:
|
293
|
-
return []
|
294
|
-
else:
|
295
|
-
return list(set(dataset.columns) - set(self.output_cols))
|
296
|
-
|
297
291
|
def _batch_inference_validate_snowpark(
|
298
292
|
self,
|
299
293
|
dataset: DataFrame,
|
300
294
|
inference_method: str,
|
301
295
|
) -> List[str]:
|
302
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
296
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe and
|
297
|
+
return the available package that exists in the snowflake anaconda channel
|
298
|
+
|
299
|
+
Args:
|
300
|
+
dataset: snowpark dataframe
|
301
|
+
inference_method: the inference method such as predict, score...
|
302
|
+
|
303
|
+
Raises:
|
304
|
+
SnowflakeMLException: If the estimator is not fitted, raise error
|
305
|
+
SnowflakeMLException: If the session is None, raise error
|
306
|
+
|
307
|
+
Returns:
|
308
|
+
A list of available package that exists in the snowflake anaconda channel
|
303
309
|
"""
|
304
310
|
if not self._is_fitted:
|
305
311
|
raise exceptions.SnowflakeMLException(
|
@@ -373,7 +379,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
373
379
|
transform_kwargs = dict(
|
374
380
|
session = dataset._session,
|
375
381
|
dependencies = self._deps,
|
376
|
-
|
382
|
+
drop_input_cols = self._drop_input_cols,
|
377
383
|
expected_output_cols_type = expected_type_inferred,
|
378
384
|
)
|
379
385
|
|
@@ -433,16 +439,16 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
433
439
|
# from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
|
434
440
|
# based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
|
435
441
|
# each row containing a list of values.
|
436
|
-
expected_dtype = "
|
442
|
+
expected_dtype = "array"
|
437
443
|
|
438
444
|
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
439
445
|
if expected_dtype == "":
|
440
|
-
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
|
446
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
|
441
447
|
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
442
|
-
expected_dtype = "
|
443
|
-
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
|
448
|
+
expected_dtype = "array"
|
449
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
|
444
450
|
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
445
|
-
expected_dtype = "
|
451
|
+
expected_dtype = "array"
|
446
452
|
else:
|
447
453
|
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
448
454
|
# We can only infer the output types from the input types if the following two statemetns are true:
|
@@ -460,7 +466,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
460
466
|
transform_kwargs = dict(
|
461
467
|
session = dataset._session,
|
462
468
|
dependencies = self._deps,
|
463
|
-
|
469
|
+
drop_input_cols = self._drop_input_cols,
|
464
470
|
expected_output_cols_type = expected_dtype,
|
465
471
|
)
|
466
472
|
|
@@ -511,7 +517,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
511
517
|
subproject=_SUBPROJECT,
|
512
518
|
)
|
513
519
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
514
|
-
|
520
|
+
drop_input_cols=self._drop_input_cols,
|
515
521
|
expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
|
516
522
|
)
|
517
523
|
self._sklearn_object = fitted_estimator
|
@@ -529,44 +535,6 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
529
535
|
assert self._sklearn_object is not None
|
530
536
|
return self._sklearn_object.embedding_
|
531
537
|
|
532
|
-
|
533
|
-
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
534
|
-
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
535
|
-
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
536
|
-
"""
|
537
|
-
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
538
|
-
if output_cols:
|
539
|
-
output_cols = [
|
540
|
-
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
541
|
-
for c in output_cols
|
542
|
-
]
|
543
|
-
elif getattr(self._sklearn_object, "classes_", None) is None:
|
544
|
-
output_cols = [output_cols_prefix]
|
545
|
-
elif self._sklearn_object is not None:
|
546
|
-
classes = self._sklearn_object.classes_
|
547
|
-
if isinstance(classes, numpy.ndarray):
|
548
|
-
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
549
|
-
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
550
|
-
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
551
|
-
output_cols = []
|
552
|
-
for i, cl in enumerate(classes):
|
553
|
-
# For binary classification, there is only one output column for each class
|
554
|
-
# ndarray as the two classes are complementary.
|
555
|
-
if len(cl) == 2:
|
556
|
-
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
557
|
-
else:
|
558
|
-
output_cols.extend([
|
559
|
-
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
560
|
-
])
|
561
|
-
else:
|
562
|
-
output_cols = []
|
563
|
-
|
564
|
-
# Make sure column names are valid snowflake identifiers.
|
565
|
-
assert output_cols is not None # Make MyPy happy
|
566
|
-
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
567
|
-
|
568
|
-
return rv
|
569
|
-
|
570
538
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
571
539
|
@telemetry.send_api_usage_telemetry(
|
572
540
|
project=_PROJECT,
|
@@ -606,7 +574,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
606
574
|
transform_kwargs = dict(
|
607
575
|
session=dataset._session,
|
608
576
|
dependencies=self._deps,
|
609
|
-
|
577
|
+
drop_input_cols = self._drop_input_cols,
|
610
578
|
expected_output_cols_type="float",
|
611
579
|
)
|
612
580
|
|
@@ -671,7 +639,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
671
639
|
transform_kwargs = dict(
|
672
640
|
session=dataset._session,
|
673
641
|
dependencies=self._deps,
|
674
|
-
|
642
|
+
drop_input_cols = self._drop_input_cols,
|
675
643
|
expected_output_cols_type="float",
|
676
644
|
)
|
677
645
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -732,7 +700,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
732
700
|
transform_kwargs = dict(
|
733
701
|
session=dataset._session,
|
734
702
|
dependencies=self._deps,
|
735
|
-
|
703
|
+
drop_input_cols = self._drop_input_cols,
|
736
704
|
expected_output_cols_type="float",
|
737
705
|
)
|
738
706
|
|
@@ -797,7 +765,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
797
765
|
transform_kwargs = dict(
|
798
766
|
session=dataset._session,
|
799
767
|
dependencies=self._deps,
|
800
|
-
|
768
|
+
drop_input_cols = self._drop_input_cols,
|
801
769
|
expected_output_cols_type="float",
|
802
770
|
)
|
803
771
|
|
@@ -853,13 +821,17 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
853
821
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
854
822
|
|
855
823
|
if isinstance(dataset, DataFrame):
|
824
|
+
self._deps = self._batch_inference_validate_snowpark(
|
825
|
+
dataset=dataset,
|
826
|
+
inference_method="score",
|
827
|
+
)
|
856
828
|
selected_cols = self._get_active_columns()
|
857
829
|
if len(selected_cols) > 0:
|
858
830
|
dataset = dataset.select(selected_cols)
|
859
831
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
860
832
|
transform_kwargs = dict(
|
861
833
|
session=dataset._session,
|
862
|
-
dependencies=["snowflake-snowpark-python"] + self.
|
834
|
+
dependencies=["snowflake-snowpark-python"] + self._deps,
|
863
835
|
score_sproc_imports=['sklearn'],
|
864
836
|
)
|
865
837
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -933,9 +905,9 @@ class OrthogonalMatchingPursuit(BaseTransformer):
|
|
933
905
|
transform_kwargs = dict(
|
934
906
|
session = dataset._session,
|
935
907
|
dependencies = self._deps,
|
936
|
-
|
937
|
-
expected_output_cols_type
|
938
|
-
n_neighbors =
|
908
|
+
drop_input_cols = self._drop_input_cols,
|
909
|
+
expected_output_cols_type="array",
|
910
|
+
n_neighbors = n_neighbors,
|
939
911
|
return_distance = return_distance
|
940
912
|
)
|
941
913
|
elif isinstance(dataset, pd.DataFrame):
|