snowflake-ml-python 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/file_utils.py +3 -3
- snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
- snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
- snowflake/ml/_internal/telemetry.py +11 -2
- snowflake/ml/_internal/utils/formatting.py +1 -1
- snowflake/ml/feature_store/feature_store.py +15 -106
- snowflake/ml/fileset/sfcfs.py +4 -3
- snowflake/ml/fileset/stage_fs.py +18 -0
- snowflake/ml/model/_api.py +9 -9
- snowflake/ml/model/_client/model/model_version_impl.py +20 -15
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +3 -9
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -5
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +7 -6
- snowflake/ml/model/_model_composer/model_composer.py +10 -8
- snowflake/ml/model/_model_composer/model_method/function_generator.py +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +2 -2
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +1 -1
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +7 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +2 -2
- snowflake/ml/model/_packager/model_handlers/llm.py +1 -1
- snowflake/ml/model/_packager/model_handlers/mlflow.py +1 -1
- snowflake/ml/model/_packager/model_handlers/pytorch.py +13 -10
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +214 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +6 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +15 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +8 -8
- snowflake/ml/model/_packager/model_handlers/torchscript.py +7 -7
- snowflake/ml/model/_packager/model_handlers/xgboost.py +8 -8
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_packager.py +8 -6
- snowflake/ml/model/custom_model.py +3 -1
- snowflake/ml/model/type_hints.py +13 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +61 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -43
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +4 -4
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +21 -17
- snowflake/ml/modeling/_internal/model_specifications.py +3 -1
- snowflake/ml/modeling/_internal/model_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +547 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +67 -114
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -9
- snowflake/ml/modeling/_internal/transformer_protocols.py +2 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +33 -61
- snowflake/ml/modeling/cluster/affinity_propagation.py +33 -61
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +33 -61
- snowflake/ml/modeling/cluster/birch.py +33 -61
- snowflake/ml/modeling/cluster/bisecting_k_means.py +33 -61
- snowflake/ml/modeling/cluster/dbscan.py +33 -61
- snowflake/ml/modeling/cluster/feature_agglomeration.py +33 -61
- snowflake/ml/modeling/cluster/k_means.py +33 -61
- snowflake/ml/modeling/cluster/mean_shift.py +33 -61
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +33 -61
- snowflake/ml/modeling/cluster/optics.py +33 -61
- snowflake/ml/modeling/cluster/spectral_biclustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_clustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_coclustering.py +33 -61
- snowflake/ml/modeling/compose/column_transformer.py +33 -61
- snowflake/ml/modeling/compose/transformed_target_regressor.py +33 -61
- snowflake/ml/modeling/covariance/elliptic_envelope.py +33 -61
- snowflake/ml/modeling/covariance/empirical_covariance.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +33 -61
- snowflake/ml/modeling/covariance/ledoit_wolf.py +33 -61
- snowflake/ml/modeling/covariance/min_cov_det.py +33 -61
- snowflake/ml/modeling/covariance/oas.py +33 -61
- snowflake/ml/modeling/covariance/shrunk_covariance.py +33 -61
- snowflake/ml/modeling/decomposition/dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/factor_analysis.py +33 -61
- snowflake/ml/modeling/decomposition/fast_ica.py +33 -61
- snowflake/ml/modeling/decomposition/incremental_pca.py +33 -61
- snowflake/ml/modeling/decomposition/kernel_pca.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/pca.py +33 -61
- snowflake/ml/modeling/decomposition/sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/truncated_svd.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/isolation_forest.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/stacking_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/voting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/voting_regressor.py +33 -61
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fdr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fpr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fwe.py +33 -61
- snowflake/ml/modeling/feature_selection/select_k_best.py +33 -61
- snowflake/ml/modeling/feature_selection/select_percentile.py +33 -61
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +33 -61
- snowflake/ml/modeling/feature_selection/variance_threshold.py +33 -61
- snowflake/ml/modeling/framework/base.py +55 -5
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +33 -61
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +33 -61
- snowflake/ml/modeling/impute/iterative_imputer.py +33 -61
- snowflake/ml/modeling/impute/knn_imputer.py +33 -61
- snowflake/ml/modeling/impute/missing_indicator.py +33 -61
- snowflake/ml/modeling/impute/simple_imputer.py +4 -15
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/nystroem.py +33 -61
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +33 -61
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +33 -61
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +36 -63
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +36 -63
- snowflake/ml/modeling/linear_model/ard_regression.py +33 -61
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/gamma_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/huber_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/lars.py +33 -61
- snowflake/ml/modeling/linear_model/lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +33 -61
- snowflake/ml/modeling/linear_model/linear_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/perceptron.py +33 -61
- snowflake/ml/modeling/linear_model/poisson_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ransac_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ridge.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_cv.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +33 -61
- snowflake/ml/modeling/manifold/isomap.py +33 -61
- snowflake/ml/modeling/manifold/mds.py +33 -61
- snowflake/ml/modeling/manifold/spectral_embedding.py +33 -61
- snowflake/ml/modeling/manifold/tsne.py +33 -61
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +33 -61
- snowflake/ml/modeling/mixture/gaussian_mixture.py +33 -61
- snowflake/ml/modeling/model_selection/grid_search_cv.py +39 -57
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +26 -57
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/output_code_classifier.py +33 -61
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/complement_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neighbors/kernel_density.py +33 -61
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_centroid.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +33 -61
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_classifier.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_regressor.py +33 -61
- snowflake/ml/modeling/preprocessing/polynomial_features.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_propagation.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_spreading.py +33 -61
- snowflake/ml/modeling/svm/linear_svc.py +33 -61
- snowflake/ml/modeling/svm/linear_svr.py +33 -61
- snowflake/ml/modeling/svm/nu_svc.py +33 -61
- snowflake/ml/modeling/svm/nu_svr.py +33 -61
- snowflake/ml/modeling/svm/svc.py +33 -61
- snowflake/ml/modeling/svm/svr.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_regressor.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +33 -61
- snowflake/ml/registry/_manager/model_manager.py +6 -2
- snowflake/ml/registry/model_registry.py +100 -27
- snowflake/ml/registry/registry.py +6 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/METADATA +43 -7
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/RECORD +211 -206
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/top_level.txt +0 -0
@@ -449,18 +449,24 @@ class SGDClassifier(BaseTransformer):
|
|
449
449
|
self._get_model_signatures(dataset)
|
450
450
|
return self
|
451
451
|
|
452
|
-
def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
|
453
|
-
if self._drop_input_cols:
|
454
|
-
return []
|
455
|
-
else:
|
456
|
-
return list(set(dataset.columns) - set(self.output_cols))
|
457
|
-
|
458
452
|
def _batch_inference_validate_snowpark(
|
459
453
|
self,
|
460
454
|
dataset: DataFrame,
|
461
455
|
inference_method: str,
|
462
456
|
) -> List[str]:
|
463
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
457
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe and
|
458
|
+
return the available package that exists in the snowflake anaconda channel
|
459
|
+
|
460
|
+
Args:
|
461
|
+
dataset: snowpark dataframe
|
462
|
+
inference_method: the inference method such as predict, score...
|
463
|
+
|
464
|
+
Raises:
|
465
|
+
SnowflakeMLException: If the estimator is not fitted, raise error
|
466
|
+
SnowflakeMLException: If the session is None, raise error
|
467
|
+
|
468
|
+
Returns:
|
469
|
+
A list of available package that exists in the snowflake anaconda channel
|
464
470
|
"""
|
465
471
|
if not self._is_fitted:
|
466
472
|
raise exceptions.SnowflakeMLException(
|
@@ -534,7 +540,7 @@ class SGDClassifier(BaseTransformer):
|
|
534
540
|
transform_kwargs = dict(
|
535
541
|
session = dataset._session,
|
536
542
|
dependencies = self._deps,
|
537
|
-
|
543
|
+
drop_input_cols = self._drop_input_cols,
|
538
544
|
expected_output_cols_type = expected_type_inferred,
|
539
545
|
)
|
540
546
|
|
@@ -594,16 +600,16 @@ class SGDClassifier(BaseTransformer):
|
|
594
600
|
# from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
|
595
601
|
# based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
|
596
602
|
# each row containing a list of values.
|
597
|
-
expected_dtype = "
|
603
|
+
expected_dtype = "array"
|
598
604
|
|
599
605
|
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
600
606
|
if expected_dtype == "":
|
601
|
-
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
|
607
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
|
602
608
|
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
603
|
-
expected_dtype = "
|
604
|
-
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
|
609
|
+
expected_dtype = "array"
|
610
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
|
605
611
|
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
606
|
-
expected_dtype = "
|
612
|
+
expected_dtype = "array"
|
607
613
|
else:
|
608
614
|
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
609
615
|
# We can only infer the output types from the input types if the following two statements are true:
|
@@ -621,7 +627,7 @@ class SGDClassifier(BaseTransformer):
|
|
621
627
|
transform_kwargs = dict(
|
622
628
|
session = dataset._session,
|
623
629
|
dependencies = self._deps,
|
624
|
-
|
630
|
+
drop_input_cols = self._drop_input_cols,
|
625
631
|
expected_output_cols_type = expected_dtype,
|
626
632
|
)
|
627
633
|
|
@@ -672,7 +678,7 @@ class SGDClassifier(BaseTransformer):
|
|
672
678
|
subproject=_SUBPROJECT,
|
673
679
|
)
|
674
680
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
675
|
-
|
681
|
+
drop_input_cols=self._drop_input_cols,
|
676
682
|
expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
|
677
683
|
)
|
678
684
|
self._sklearn_object = fitted_estimator
|
@@ -690,44 +696,6 @@ class SGDClassifier(BaseTransformer):
|
|
690
696
|
assert self._sklearn_object is not None
|
691
697
|
return self._sklearn_object.embedding_
|
692
698
|
|
693
|
-
|
694
|
-
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
695
|
-
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
696
|
-
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
697
|
-
"""
|
698
|
-
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
699
|
-
if output_cols:
|
700
|
-
output_cols = [
|
701
|
-
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
702
|
-
for c in output_cols
|
703
|
-
]
|
704
|
-
elif getattr(self._sklearn_object, "classes_", None) is None:
|
705
|
-
output_cols = [output_cols_prefix]
|
706
|
-
elif self._sklearn_object is not None:
|
707
|
-
classes = self._sklearn_object.classes_
|
708
|
-
if isinstance(classes, numpy.ndarray):
|
709
|
-
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
710
|
-
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
711
|
-
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
712
|
-
output_cols = []
|
713
|
-
for i, cl in enumerate(classes):
|
714
|
-
# For binary classification, there is only one output column for each class
|
715
|
-
# ndarray as the two classes are complementary.
|
716
|
-
if len(cl) == 2:
|
717
|
-
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
718
|
-
else:
|
719
|
-
output_cols.extend([
|
720
|
-
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
721
|
-
])
|
722
|
-
else:
|
723
|
-
output_cols = []
|
724
|
-
|
725
|
-
# Make sure column names are valid snowflake identifiers.
|
726
|
-
assert output_cols is not None # Make MyPy happy
|
727
|
-
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
728
|
-
|
729
|
-
return rv
|
730
|
-
|
731
699
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
732
700
|
@telemetry.send_api_usage_telemetry(
|
733
701
|
project=_PROJECT,
|
@@ -769,7 +737,7 @@ class SGDClassifier(BaseTransformer):
|
|
769
737
|
transform_kwargs = dict(
|
770
738
|
session=dataset._session,
|
771
739
|
dependencies=self._deps,
|
772
|
-
|
740
|
+
drop_input_cols = self._drop_input_cols,
|
773
741
|
expected_output_cols_type="float",
|
774
742
|
)
|
775
743
|
|
@@ -836,7 +804,7 @@ class SGDClassifier(BaseTransformer):
|
|
836
804
|
transform_kwargs = dict(
|
837
805
|
session=dataset._session,
|
838
806
|
dependencies=self._deps,
|
839
|
-
|
807
|
+
drop_input_cols = self._drop_input_cols,
|
840
808
|
expected_output_cols_type="float",
|
841
809
|
)
|
842
810
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -899,7 +867,7 @@ class SGDClassifier(BaseTransformer):
|
|
899
867
|
transform_kwargs = dict(
|
900
868
|
session=dataset._session,
|
901
869
|
dependencies=self._deps,
|
902
|
-
|
870
|
+
drop_input_cols = self._drop_input_cols,
|
903
871
|
expected_output_cols_type="float",
|
904
872
|
)
|
905
873
|
|
@@ -964,7 +932,7 @@ class SGDClassifier(BaseTransformer):
|
|
964
932
|
transform_kwargs = dict(
|
965
933
|
session=dataset._session,
|
966
934
|
dependencies=self._deps,
|
967
|
-
|
935
|
+
drop_input_cols = self._drop_input_cols,
|
968
936
|
expected_output_cols_type="float",
|
969
937
|
)
|
970
938
|
|
@@ -1020,13 +988,17 @@ class SGDClassifier(BaseTransformer):
|
|
1020
988
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
1021
989
|
|
1022
990
|
if isinstance(dataset, DataFrame):
|
991
|
+
self._deps = self._batch_inference_validate_snowpark(
|
992
|
+
dataset=dataset,
|
993
|
+
inference_method="score",
|
994
|
+
)
|
1023
995
|
selected_cols = self._get_active_columns()
|
1024
996
|
if len(selected_cols) > 0:
|
1025
997
|
dataset = dataset.select(selected_cols)
|
1026
998
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
1027
999
|
transform_kwargs = dict(
|
1028
1000
|
session=dataset._session,
|
1029
|
-
dependencies=["snowflake-snowpark-python"] + self.
|
1001
|
+
dependencies=["snowflake-snowpark-python"] + self._deps,
|
1030
1002
|
score_sproc_imports=['sklearn'],
|
1031
1003
|
)
|
1032
1004
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1100,9 +1072,9 @@ class SGDClassifier(BaseTransformer):
|
|
1100
1072
|
transform_kwargs = dict(
|
1101
1073
|
session = dataset._session,
|
1102
1074
|
dependencies = self._deps,
|
1103
|
-
|
1104
|
-
expected_output_cols_type
|
1105
|
-
n_neighbors =
|
1075
|
+
drop_input_cols = self._drop_input_cols,
|
1076
|
+
expected_output_cols_type="array",
|
1077
|
+
n_neighbors = n_neighbors,
|
1106
1078
|
return_distance = return_distance
|
1107
1079
|
)
|
1108
1080
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -347,18 +347,24 @@ class SGDOneClassSVM(BaseTransformer):
|
|
347
347
|
self._get_model_signatures(dataset)
|
348
348
|
return self
|
349
349
|
|
350
|
-
def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
|
351
|
-
if self._drop_input_cols:
|
352
|
-
return []
|
353
|
-
else:
|
354
|
-
return list(set(dataset.columns) - set(self.output_cols))
|
355
|
-
|
356
350
|
def _batch_inference_validate_snowpark(
|
357
351
|
self,
|
358
352
|
dataset: DataFrame,
|
359
353
|
inference_method: str,
|
360
354
|
) -> List[str]:
|
361
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
355
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe and
|
356
|
+
return the available package that exists in the snowflake anaconda channel
|
357
|
+
|
358
|
+
Args:
|
359
|
+
dataset: snowpark dataframe
|
360
|
+
inference_method: the inference method such as predict, score...
|
361
|
+
|
362
|
+
Raises:
|
363
|
+
SnowflakeMLException: If the estimator is not fitted, raise error
|
364
|
+
SnowflakeMLException: If the session is None, raise error
|
365
|
+
|
366
|
+
Returns:
|
367
|
+
A list of available package that exists in the snowflake anaconda channel
|
362
368
|
"""
|
363
369
|
if not self._is_fitted:
|
364
370
|
raise exceptions.SnowflakeMLException(
|
@@ -432,7 +438,7 @@ class SGDOneClassSVM(BaseTransformer):
|
|
432
438
|
transform_kwargs = dict(
|
433
439
|
session = dataset._session,
|
434
440
|
dependencies = self._deps,
|
435
|
-
|
441
|
+
drop_input_cols = self._drop_input_cols,
|
436
442
|
expected_output_cols_type = expected_type_inferred,
|
437
443
|
)
|
438
444
|
|
@@ -492,16 +498,16 @@ class SGDOneClassSVM(BaseTransformer):
|
|
492
498
|
# from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
|
493
499
|
# based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
|
494
500
|
# each row containing a list of values.
|
495
|
-
expected_dtype = "
|
501
|
+
expected_dtype = "array"
|
496
502
|
|
497
503
|
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
498
504
|
if expected_dtype == "":
|
499
|
-
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
|
505
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
|
500
506
|
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
501
|
-
expected_dtype = "
|
502
|
-
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
|
507
|
+
expected_dtype = "array"
|
508
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
|
503
509
|
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
504
|
-
expected_dtype = "
|
510
|
+
expected_dtype = "array"
|
505
511
|
else:
|
506
512
|
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
507
513
|
# We can only infer the output types from the input types if the following two statements are true:
|
@@ -519,7 +525,7 @@ class SGDOneClassSVM(BaseTransformer):
|
|
519
525
|
transform_kwargs = dict(
|
520
526
|
session = dataset._session,
|
521
527
|
dependencies = self._deps,
|
522
|
-
|
528
|
+
drop_input_cols = self._drop_input_cols,
|
523
529
|
expected_output_cols_type = expected_dtype,
|
524
530
|
)
|
525
531
|
|
@@ -572,7 +578,7 @@ class SGDOneClassSVM(BaseTransformer):
|
|
572
578
|
subproject=_SUBPROJECT,
|
573
579
|
)
|
574
580
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
575
|
-
|
581
|
+
drop_input_cols=self._drop_input_cols,
|
576
582
|
expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
|
577
583
|
)
|
578
584
|
self._sklearn_object = fitted_estimator
|
@@ -590,44 +596,6 @@ class SGDOneClassSVM(BaseTransformer):
|
|
590
596
|
assert self._sklearn_object is not None
|
591
597
|
return self._sklearn_object.embedding_
|
592
598
|
|
593
|
-
|
594
|
-
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
595
|
-
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
596
|
-
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
597
|
-
"""
|
598
|
-
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
599
|
-
if output_cols:
|
600
|
-
output_cols = [
|
601
|
-
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
602
|
-
for c in output_cols
|
603
|
-
]
|
604
|
-
elif getattr(self._sklearn_object, "classes_", None) is None:
|
605
|
-
output_cols = [output_cols_prefix]
|
606
|
-
elif self._sklearn_object is not None:
|
607
|
-
classes = self._sklearn_object.classes_
|
608
|
-
if isinstance(classes, numpy.ndarray):
|
609
|
-
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
610
|
-
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
611
|
-
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
612
|
-
output_cols = []
|
613
|
-
for i, cl in enumerate(classes):
|
614
|
-
# For binary classification, there is only one output column for each class
|
615
|
-
# ndarray as the two classes are complementary.
|
616
|
-
if len(cl) == 2:
|
617
|
-
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
618
|
-
else:
|
619
|
-
output_cols.extend([
|
620
|
-
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
621
|
-
])
|
622
|
-
else:
|
623
|
-
output_cols = []
|
624
|
-
|
625
|
-
# Make sure column names are valid snowflake identifiers.
|
626
|
-
assert output_cols is not None # Make MyPy happy
|
627
|
-
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
628
|
-
|
629
|
-
return rv
|
630
|
-
|
631
599
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
632
600
|
@telemetry.send_api_usage_telemetry(
|
633
601
|
project=_PROJECT,
|
@@ -667,7 +635,7 @@ class SGDOneClassSVM(BaseTransformer):
|
|
667
635
|
transform_kwargs = dict(
|
668
636
|
session=dataset._session,
|
669
637
|
dependencies=self._deps,
|
670
|
-
|
638
|
+
drop_input_cols = self._drop_input_cols,
|
671
639
|
expected_output_cols_type="float",
|
672
640
|
)
|
673
641
|
|
@@ -732,7 +700,7 @@ class SGDOneClassSVM(BaseTransformer):
|
|
732
700
|
transform_kwargs = dict(
|
733
701
|
session=dataset._session,
|
734
702
|
dependencies=self._deps,
|
735
|
-
|
703
|
+
drop_input_cols = self._drop_input_cols,
|
736
704
|
expected_output_cols_type="float",
|
737
705
|
)
|
738
706
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -795,7 +763,7 @@ class SGDOneClassSVM(BaseTransformer):
|
|
795
763
|
transform_kwargs = dict(
|
796
764
|
session=dataset._session,
|
797
765
|
dependencies=self._deps,
|
798
|
-
|
766
|
+
drop_input_cols = self._drop_input_cols,
|
799
767
|
expected_output_cols_type="float",
|
800
768
|
)
|
801
769
|
|
@@ -862,7 +830,7 @@ class SGDOneClassSVM(BaseTransformer):
|
|
862
830
|
transform_kwargs = dict(
|
863
831
|
session=dataset._session,
|
864
832
|
dependencies=self._deps,
|
865
|
-
|
833
|
+
drop_input_cols = self._drop_input_cols,
|
866
834
|
expected_output_cols_type="float",
|
867
835
|
)
|
868
836
|
|
@@ -916,13 +884,17 @@ class SGDOneClassSVM(BaseTransformer):
|
|
916
884
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
917
885
|
|
918
886
|
if isinstance(dataset, DataFrame):
|
887
|
+
self._deps = self._batch_inference_validate_snowpark(
|
888
|
+
dataset=dataset,
|
889
|
+
inference_method="score",
|
890
|
+
)
|
919
891
|
selected_cols = self._get_active_columns()
|
920
892
|
if len(selected_cols) > 0:
|
921
893
|
dataset = dataset.select(selected_cols)
|
922
894
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
923
895
|
transform_kwargs = dict(
|
924
896
|
session=dataset._session,
|
925
|
-
dependencies=["snowflake-snowpark-python"] + self.
|
897
|
+
dependencies=["snowflake-snowpark-python"] + self._deps,
|
926
898
|
score_sproc_imports=['sklearn'],
|
927
899
|
)
|
928
900
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -996,9 +968,9 @@ class SGDOneClassSVM(BaseTransformer):
|
|
996
968
|
transform_kwargs = dict(
|
997
969
|
session = dataset._session,
|
998
970
|
dependencies = self._deps,
|
999
|
-
|
1000
|
-
expected_output_cols_type
|
1001
|
-
n_neighbors =
|
971
|
+
drop_input_cols = self._drop_input_cols,
|
972
|
+
expected_output_cols_type="array",
|
973
|
+
n_neighbors = n_neighbors,
|
1002
974
|
return_distance = return_distance
|
1003
975
|
)
|
1004
976
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -415,18 +415,24 @@ class SGDRegressor(BaseTransformer):
|
|
415
415
|
self._get_model_signatures(dataset)
|
416
416
|
return self
|
417
417
|
|
418
|
-
def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
|
419
|
-
if self._drop_input_cols:
|
420
|
-
return []
|
421
|
-
else:
|
422
|
-
return list(set(dataset.columns) - set(self.output_cols))
|
423
|
-
|
424
418
|
def _batch_inference_validate_snowpark(
|
425
419
|
self,
|
426
420
|
dataset: DataFrame,
|
427
421
|
inference_method: str,
|
428
422
|
) -> List[str]:
|
429
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
423
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe and
|
424
|
+
return the available package that exists in the snowflake anaconda channel
|
425
|
+
|
426
|
+
Args:
|
427
|
+
dataset: snowpark dataframe
|
428
|
+
inference_method: the inference method such as predict, score...
|
429
|
+
|
430
|
+
Raises:
|
431
|
+
SnowflakeMLException: If the estimator is not fitted, raise error
|
432
|
+
SnowflakeMLException: If the session is None, raise error
|
433
|
+
|
434
|
+
Returns:
|
435
|
+
A list of available package that exists in the snowflake anaconda channel
|
430
436
|
"""
|
431
437
|
if not self._is_fitted:
|
432
438
|
raise exceptions.SnowflakeMLException(
|
@@ -500,7 +506,7 @@ class SGDRegressor(BaseTransformer):
|
|
500
506
|
transform_kwargs = dict(
|
501
507
|
session = dataset._session,
|
502
508
|
dependencies = self._deps,
|
503
|
-
|
509
|
+
drop_input_cols = self._drop_input_cols,
|
504
510
|
expected_output_cols_type = expected_type_inferred,
|
505
511
|
)
|
506
512
|
|
@@ -560,16 +566,16 @@ class SGDRegressor(BaseTransformer):
|
|
560
566
|
# from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
|
561
567
|
# based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
|
562
568
|
# each row containing a list of values.
|
563
|
-
expected_dtype = "
|
569
|
+
expected_dtype = "array"
|
564
570
|
|
565
571
|
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
566
572
|
if expected_dtype == "":
|
567
|
-
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
|
573
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
|
568
574
|
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
569
|
-
expected_dtype = "
|
570
|
-
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
|
575
|
+
expected_dtype = "array"
|
576
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
|
571
577
|
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
572
|
-
expected_dtype = "
|
578
|
+
expected_dtype = "array"
|
573
579
|
else:
|
574
580
|
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
575
581
|
# We can only infer the output types from the input types if the following two statemetns are true:
|
@@ -587,7 +593,7 @@ class SGDRegressor(BaseTransformer):
|
|
587
593
|
transform_kwargs = dict(
|
588
594
|
session = dataset._session,
|
589
595
|
dependencies = self._deps,
|
590
|
-
|
596
|
+
drop_input_cols = self._drop_input_cols,
|
591
597
|
expected_output_cols_type = expected_dtype,
|
592
598
|
)
|
593
599
|
|
@@ -638,7 +644,7 @@ class SGDRegressor(BaseTransformer):
|
|
638
644
|
subproject=_SUBPROJECT,
|
639
645
|
)
|
640
646
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
641
|
-
|
647
|
+
drop_input_cols=self._drop_input_cols,
|
642
648
|
expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
|
643
649
|
)
|
644
650
|
self._sklearn_object = fitted_estimator
|
@@ -656,44 +662,6 @@ class SGDRegressor(BaseTransformer):
|
|
656
662
|
assert self._sklearn_object is not None
|
657
663
|
return self._sklearn_object.embedding_
|
658
664
|
|
659
|
-
|
660
|
-
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
661
|
-
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
662
|
-
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
663
|
-
"""
|
664
|
-
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
665
|
-
if output_cols:
|
666
|
-
output_cols = [
|
667
|
-
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
668
|
-
for c in output_cols
|
669
|
-
]
|
670
|
-
elif getattr(self._sklearn_object, "classes_", None) is None:
|
671
|
-
output_cols = [output_cols_prefix]
|
672
|
-
elif self._sklearn_object is not None:
|
673
|
-
classes = self._sklearn_object.classes_
|
674
|
-
if isinstance(classes, numpy.ndarray):
|
675
|
-
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
676
|
-
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
677
|
-
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
678
|
-
output_cols = []
|
679
|
-
for i, cl in enumerate(classes):
|
680
|
-
# For binary classification, there is only one output column for each class
|
681
|
-
# ndarray as the two classes are complementary.
|
682
|
-
if len(cl) == 2:
|
683
|
-
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
684
|
-
else:
|
685
|
-
output_cols.extend([
|
686
|
-
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
687
|
-
])
|
688
|
-
else:
|
689
|
-
output_cols = []
|
690
|
-
|
691
|
-
# Make sure column names are valid snowflake identifiers.
|
692
|
-
assert output_cols is not None # Make MyPy happy
|
693
|
-
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
694
|
-
|
695
|
-
return rv
|
696
|
-
|
697
665
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
698
666
|
@telemetry.send_api_usage_telemetry(
|
699
667
|
project=_PROJECT,
|
@@ -733,7 +701,7 @@ class SGDRegressor(BaseTransformer):
|
|
733
701
|
transform_kwargs = dict(
|
734
702
|
session=dataset._session,
|
735
703
|
dependencies=self._deps,
|
736
|
-
|
704
|
+
drop_input_cols = self._drop_input_cols,
|
737
705
|
expected_output_cols_type="float",
|
738
706
|
)
|
739
707
|
|
@@ -798,7 +766,7 @@ class SGDRegressor(BaseTransformer):
|
|
798
766
|
transform_kwargs = dict(
|
799
767
|
session=dataset._session,
|
800
768
|
dependencies=self._deps,
|
801
|
-
|
769
|
+
drop_input_cols = self._drop_input_cols,
|
802
770
|
expected_output_cols_type="float",
|
803
771
|
)
|
804
772
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -859,7 +827,7 @@ class SGDRegressor(BaseTransformer):
|
|
859
827
|
transform_kwargs = dict(
|
860
828
|
session=dataset._session,
|
861
829
|
dependencies=self._deps,
|
862
|
-
|
830
|
+
drop_input_cols = self._drop_input_cols,
|
863
831
|
expected_output_cols_type="float",
|
864
832
|
)
|
865
833
|
|
@@ -924,7 +892,7 @@ class SGDRegressor(BaseTransformer):
|
|
924
892
|
transform_kwargs = dict(
|
925
893
|
session=dataset._session,
|
926
894
|
dependencies=self._deps,
|
927
|
-
|
895
|
+
drop_input_cols = self._drop_input_cols,
|
928
896
|
expected_output_cols_type="float",
|
929
897
|
)
|
930
898
|
|
@@ -980,13 +948,17 @@ class SGDRegressor(BaseTransformer):
|
|
980
948
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
981
949
|
|
982
950
|
if isinstance(dataset, DataFrame):
|
951
|
+
self._deps = self._batch_inference_validate_snowpark(
|
952
|
+
dataset=dataset,
|
953
|
+
inference_method="score",
|
954
|
+
)
|
983
955
|
selected_cols = self._get_active_columns()
|
984
956
|
if len(selected_cols) > 0:
|
985
957
|
dataset = dataset.select(selected_cols)
|
986
958
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
987
959
|
transform_kwargs = dict(
|
988
960
|
session=dataset._session,
|
989
|
-
dependencies=["snowflake-snowpark-python"] + self.
|
961
|
+
dependencies=["snowflake-snowpark-python"] + self._deps,
|
990
962
|
score_sproc_imports=['sklearn'],
|
991
963
|
)
|
992
964
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1060,9 +1032,9 @@ class SGDRegressor(BaseTransformer):
|
|
1060
1032
|
transform_kwargs = dict(
|
1061
1033
|
session = dataset._session,
|
1062
1034
|
dependencies = self._deps,
|
1063
|
-
|
1064
|
-
expected_output_cols_type
|
1065
|
-
n_neighbors =
|
1035
|
+
drop_input_cols = self._drop_input_cols,
|
1036
|
+
expected_output_cols_type="array",
|
1037
|
+
n_neighbors = n_neighbors,
|
1066
1038
|
return_distance = return_distance
|
1067
1039
|
)
|
1068
1040
|
elif isinstance(dataset, pd.DataFrame):
|