snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +72 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
- snowflake/ml/_internal/telemetry.py +1 -0
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/sql_identifier.py +14 -1
- snowflake/ml/dataset/__init__.py +11 -0
- snowflake/ml/dataset/dataset.py +455 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +199 -0
- snowflake/ml/feature_store/__init__.py +6 -0
- snowflake/ml/feature_store/access_manager.py +279 -0
- snowflake/ml/feature_store/feature_store.py +544 -358
- snowflake/ml/feature_store/feature_view.py +55 -16
- snowflake/ml/fileset/embedded_stage_fs.py +149 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +160 -0
- snowflake/ml/fileset/stage_fs.py +25 -10
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +65 -31
- snowflake/ml/model/_client/model/model_version_impl.py +159 -2
- snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
- snowflake/ml/model/_client/ops/model_ops.py +268 -83
- snowflake/ml/model/_client/sql/_base.py +34 -0
- snowflake/ml/model/_client/sql/model.py +42 -47
- snowflake/ml/model/_client/sql/model_version.py +164 -39
- snowflake/ml/model/_client/sql/stage.py +6 -32
- snowflake/ml/model/_client/sql/tag.py +32 -56
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +64 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +538 -36
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/_manager/model_manager.py +36 -7
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neighbors".replace("skle
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class LocalOutlierFactor(BaseTransformer):
|
70
64
|
r"""Unsupervised Outlier Detection using the Local Outlier Factor (LOF)
|
71
65
|
For more details on this class, see [sklearn.neighbors.LocalOutlierFactor]
|
@@ -341,20 +335,17 @@ class LocalOutlierFactor(BaseTransformer):
|
|
341
335
|
self,
|
342
336
|
dataset: DataFrame,
|
343
337
|
inference_method: str,
|
344
|
-
) ->
|
345
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
346
|
-
return the available package that exists in the snowflake anaconda channel
|
338
|
+
) -> None:
|
339
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
347
340
|
|
348
341
|
Args:
|
349
342
|
dataset: snowpark dataframe
|
350
343
|
inference_method: the inference method such as predict, score...
|
351
|
-
|
344
|
+
|
352
345
|
Raises:
|
353
346
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
354
347
|
SnowflakeMLException: If the session is None, raise error
|
355
348
|
|
356
|
-
Returns:
|
357
|
-
A list of available package that exists in the snowflake anaconda channel
|
358
349
|
"""
|
359
350
|
if not self._is_fitted:
|
360
351
|
raise exceptions.SnowflakeMLException(
|
@@ -372,9 +363,7 @@ class LocalOutlierFactor(BaseTransformer):
|
|
372
363
|
"Session must not specified for snowpark dataset."
|
373
364
|
),
|
374
365
|
)
|
375
|
-
|
376
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
377
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
366
|
+
|
378
367
|
|
379
368
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
380
369
|
@telemetry.send_api_usage_telemetry(
|
@@ -422,7 +411,8 @@ class LocalOutlierFactor(BaseTransformer):
|
|
422
411
|
|
423
412
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
424
413
|
|
425
|
-
self.
|
414
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
415
|
+
self._deps = self._get_dependencies()
|
426
416
|
assert isinstance(
|
427
417
|
dataset._session, Session
|
428
418
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -505,10 +495,8 @@ class LocalOutlierFactor(BaseTransformer):
|
|
505
495
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
506
496
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
507
497
|
|
508
|
-
self.
|
509
|
-
|
510
|
-
inference_method=inference_method,
|
511
|
-
)
|
498
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
499
|
+
self._deps = self._get_dependencies()
|
512
500
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
513
501
|
|
514
502
|
transform_kwargs = dict(
|
@@ -577,16 +565,40 @@ class LocalOutlierFactor(BaseTransformer):
|
|
577
565
|
self._is_fitted = True
|
578
566
|
return output_result
|
579
567
|
|
568
|
+
|
569
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
570
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
571
|
+
""" Method not supported for this class.
|
572
|
+
|
580
573
|
|
581
|
-
|
582
|
-
|
583
|
-
|
574
|
+
Raises:
|
575
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
576
|
+
|
577
|
+
Args:
|
578
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
579
|
+
Snowpark or Pandas DataFrame.
|
580
|
+
output_cols_prefix: Prefix for the response columns
|
584
581
|
Returns:
|
585
582
|
Transformed dataset.
|
586
583
|
"""
|
587
|
-
self.
|
588
|
-
|
589
|
-
|
584
|
+
self._infer_input_output_cols(dataset)
|
585
|
+
super()._check_dataset_type(dataset)
|
586
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
587
|
+
estimator=self._sklearn_object,
|
588
|
+
dataset=dataset,
|
589
|
+
input_cols=self.input_cols,
|
590
|
+
label_cols=self.label_cols,
|
591
|
+
sample_weight_col=self.sample_weight_col,
|
592
|
+
autogenerated=self._autogenerated,
|
593
|
+
subproject=_SUBPROJECT,
|
594
|
+
)
|
595
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
596
|
+
drop_input_cols=self._drop_input_cols,
|
597
|
+
expected_output_cols_list=self.output_cols,
|
598
|
+
)
|
599
|
+
self._sklearn_object = fitted_estimator
|
600
|
+
self._is_fitted = True
|
601
|
+
return output_result
|
590
602
|
|
591
603
|
|
592
604
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -677,10 +689,8 @@ class LocalOutlierFactor(BaseTransformer):
|
|
677
689
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
678
690
|
|
679
691
|
if isinstance(dataset, DataFrame):
|
680
|
-
self.
|
681
|
-
|
682
|
-
inference_method=inference_method,
|
683
|
-
)
|
692
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
693
|
+
self._deps = self._get_dependencies()
|
684
694
|
assert isinstance(
|
685
695
|
dataset._session, Session
|
686
696
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -745,10 +755,8 @@ class LocalOutlierFactor(BaseTransformer):
|
|
745
755
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
746
756
|
|
747
757
|
if isinstance(dataset, DataFrame):
|
748
|
-
self.
|
749
|
-
|
750
|
-
inference_method=inference_method,
|
751
|
-
)
|
758
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
759
|
+
self._deps = self._get_dependencies()
|
752
760
|
assert isinstance(
|
753
761
|
dataset._session, Session
|
754
762
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -812,10 +820,8 @@ class LocalOutlierFactor(BaseTransformer):
|
|
812
820
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
813
821
|
|
814
822
|
if isinstance(dataset, DataFrame):
|
815
|
-
self.
|
816
|
-
|
817
|
-
inference_method=inference_method,
|
818
|
-
)
|
823
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
824
|
+
self._deps = self._get_dependencies()
|
819
825
|
assert isinstance(
|
820
826
|
dataset._session, Session
|
821
827
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -883,10 +889,8 @@ class LocalOutlierFactor(BaseTransformer):
|
|
883
889
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
884
890
|
|
885
891
|
if isinstance(dataset, DataFrame):
|
886
|
-
self.
|
887
|
-
|
888
|
-
inference_method=inference_method,
|
889
|
-
)
|
892
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
893
|
+
self._deps = self._get_dependencies()
|
890
894
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
891
895
|
transform_kwargs = dict(
|
892
896
|
session=dataset._session,
|
@@ -948,17 +952,15 @@ class LocalOutlierFactor(BaseTransformer):
|
|
948
952
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
949
953
|
|
950
954
|
if isinstance(dataset, DataFrame):
|
951
|
-
self.
|
952
|
-
|
953
|
-
inference_method="score",
|
954
|
-
)
|
955
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
956
|
+
self._deps = self._get_dependencies()
|
955
957
|
selected_cols = self._get_active_columns()
|
956
958
|
if len(selected_cols) > 0:
|
957
959
|
dataset = dataset.select(selected_cols)
|
958
960
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
959
961
|
transform_kwargs = dict(
|
960
962
|
session=dataset._session,
|
961
|
-
dependencies=
|
963
|
+
dependencies=self._deps,
|
962
964
|
score_sproc_imports=['sklearn'],
|
963
965
|
)
|
964
966
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1025,11 +1027,8 @@ class LocalOutlierFactor(BaseTransformer):
|
|
1025
1027
|
|
1026
1028
|
if isinstance(dataset, DataFrame):
|
1027
1029
|
|
1028
|
-
self.
|
1029
|
-
|
1030
|
-
inference_method=inference_method,
|
1031
|
-
|
1032
|
-
)
|
1030
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1031
|
+
self._deps = self._get_dependencies()
|
1033
1032
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1034
1033
|
transform_kwargs = dict(
|
1035
1034
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neighbors".replace("skle
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class NearestCentroid(BaseTransformer):
|
70
64
|
r"""Nearest centroid classifier
|
71
65
|
For more details on this class, see [sklearn.neighbors.NearestCentroid]
|
@@ -274,20 +268,17 @@ class NearestCentroid(BaseTransformer):
|
|
274
268
|
self,
|
275
269
|
dataset: DataFrame,
|
276
270
|
inference_method: str,
|
277
|
-
) ->
|
278
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
279
|
-
return the available package that exists in the snowflake anaconda channel
|
271
|
+
) -> None:
|
272
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
280
273
|
|
281
274
|
Args:
|
282
275
|
dataset: snowpark dataframe
|
283
276
|
inference_method: the inference method such as predict, score...
|
284
|
-
|
277
|
+
|
285
278
|
Raises:
|
286
279
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
287
280
|
SnowflakeMLException: If the session is None, raise error
|
288
281
|
|
289
|
-
Returns:
|
290
|
-
A list of available package that exists in the snowflake anaconda channel
|
291
282
|
"""
|
292
283
|
if not self._is_fitted:
|
293
284
|
raise exceptions.SnowflakeMLException(
|
@@ -305,9 +296,7 @@ class NearestCentroid(BaseTransformer):
|
|
305
296
|
"Session must not specified for snowpark dataset."
|
306
297
|
),
|
307
298
|
)
|
308
|
-
|
309
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
310
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
299
|
+
|
311
300
|
|
312
301
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
313
302
|
@telemetry.send_api_usage_telemetry(
|
@@ -355,7 +344,8 @@ class NearestCentroid(BaseTransformer):
|
|
355
344
|
|
356
345
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
357
346
|
|
358
|
-
self.
|
347
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
348
|
+
self._deps = self._get_dependencies()
|
359
349
|
assert isinstance(
|
360
350
|
dataset._session, Session
|
361
351
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -438,10 +428,8 @@ class NearestCentroid(BaseTransformer):
|
|
438
428
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
439
429
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
440
430
|
|
441
|
-
self.
|
442
|
-
|
443
|
-
inference_method=inference_method,
|
444
|
-
)
|
431
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
432
|
+
self._deps = self._get_dependencies()
|
445
433
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
446
434
|
|
447
435
|
transform_kwargs = dict(
|
@@ -508,16 +496,40 @@ class NearestCentroid(BaseTransformer):
|
|
508
496
|
self._is_fitted = True
|
509
497
|
return output_result
|
510
498
|
|
499
|
+
|
500
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
501
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
502
|
+
""" Method not supported for this class.
|
511
503
|
|
512
|
-
|
513
|
-
|
514
|
-
|
504
|
+
|
505
|
+
Raises:
|
506
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
507
|
+
|
508
|
+
Args:
|
509
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
510
|
+
Snowpark or Pandas DataFrame.
|
511
|
+
output_cols_prefix: Prefix for the response columns
|
515
512
|
Returns:
|
516
513
|
Transformed dataset.
|
517
514
|
"""
|
518
|
-
self.
|
519
|
-
|
520
|
-
|
515
|
+
self._infer_input_output_cols(dataset)
|
516
|
+
super()._check_dataset_type(dataset)
|
517
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
518
|
+
estimator=self._sklearn_object,
|
519
|
+
dataset=dataset,
|
520
|
+
input_cols=self.input_cols,
|
521
|
+
label_cols=self.label_cols,
|
522
|
+
sample_weight_col=self.sample_weight_col,
|
523
|
+
autogenerated=self._autogenerated,
|
524
|
+
subproject=_SUBPROJECT,
|
525
|
+
)
|
526
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
527
|
+
drop_input_cols=self._drop_input_cols,
|
528
|
+
expected_output_cols_list=self.output_cols,
|
529
|
+
)
|
530
|
+
self._sklearn_object = fitted_estimator
|
531
|
+
self._is_fitted = True
|
532
|
+
return output_result
|
521
533
|
|
522
534
|
|
523
535
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -608,10 +620,8 @@ class NearestCentroid(BaseTransformer):
|
|
608
620
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
609
621
|
|
610
622
|
if isinstance(dataset, DataFrame):
|
611
|
-
self.
|
612
|
-
|
613
|
-
inference_method=inference_method,
|
614
|
-
)
|
623
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
624
|
+
self._deps = self._get_dependencies()
|
615
625
|
assert isinstance(
|
616
626
|
dataset._session, Session
|
617
627
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -676,10 +686,8 @@ class NearestCentroid(BaseTransformer):
|
|
676
686
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
677
687
|
|
678
688
|
if isinstance(dataset, DataFrame):
|
679
|
-
self.
|
680
|
-
|
681
|
-
inference_method=inference_method,
|
682
|
-
)
|
689
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
690
|
+
self._deps = self._get_dependencies()
|
683
691
|
assert isinstance(
|
684
692
|
dataset._session, Session
|
685
693
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -741,10 +749,8 @@ class NearestCentroid(BaseTransformer):
|
|
741
749
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
742
750
|
|
743
751
|
if isinstance(dataset, DataFrame):
|
744
|
-
self.
|
745
|
-
|
746
|
-
inference_method=inference_method,
|
747
|
-
)
|
752
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
753
|
+
self._deps = self._get_dependencies()
|
748
754
|
assert isinstance(
|
749
755
|
dataset._session, Session
|
750
756
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -810,10 +816,8 @@ class NearestCentroid(BaseTransformer):
|
|
810
816
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
811
817
|
|
812
818
|
if isinstance(dataset, DataFrame):
|
813
|
-
self.
|
814
|
-
|
815
|
-
inference_method=inference_method,
|
816
|
-
)
|
819
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
820
|
+
self._deps = self._get_dependencies()
|
817
821
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
818
822
|
transform_kwargs = dict(
|
819
823
|
session=dataset._session,
|
@@ -877,17 +881,15 @@ class NearestCentroid(BaseTransformer):
|
|
877
881
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
878
882
|
|
879
883
|
if isinstance(dataset, DataFrame):
|
880
|
-
self.
|
881
|
-
|
882
|
-
inference_method="score",
|
883
|
-
)
|
884
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
885
|
+
self._deps = self._get_dependencies()
|
884
886
|
selected_cols = self._get_active_columns()
|
885
887
|
if len(selected_cols) > 0:
|
886
888
|
dataset = dataset.select(selected_cols)
|
887
889
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
888
890
|
transform_kwargs = dict(
|
889
891
|
session=dataset._session,
|
890
|
-
dependencies=
|
892
|
+
dependencies=self._deps,
|
891
893
|
score_sproc_imports=['sklearn'],
|
892
894
|
)
|
893
895
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -952,11 +954,8 @@ class NearestCentroid(BaseTransformer):
|
|
952
954
|
|
953
955
|
if isinstance(dataset, DataFrame):
|
954
956
|
|
955
|
-
self.
|
956
|
-
|
957
|
-
inference_method=inference_method,
|
958
|
-
|
959
|
-
)
|
957
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
958
|
+
self._deps = self._get_dependencies()
|
960
959
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
961
960
|
transform_kwargs = dict(
|
962
961
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neighbors".replace("skle
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class NearestNeighbors(BaseTransformer):
|
70
64
|
r"""Unsupervised learner for implementing neighbor searches
|
71
65
|
For more details on this class, see [sklearn.neighbors.NearestNeighbors]
|
@@ -324,20 +318,17 @@ class NearestNeighbors(BaseTransformer):
|
|
324
318
|
self,
|
325
319
|
dataset: DataFrame,
|
326
320
|
inference_method: str,
|
327
|
-
) ->
|
328
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
329
|
-
return the available package that exists in the snowflake anaconda channel
|
321
|
+
) -> None:
|
322
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
330
323
|
|
331
324
|
Args:
|
332
325
|
dataset: snowpark dataframe
|
333
326
|
inference_method: the inference method such as predict, score...
|
334
|
-
|
327
|
+
|
335
328
|
Raises:
|
336
329
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
337
330
|
SnowflakeMLException: If the session is None, raise error
|
338
331
|
|
339
|
-
Returns:
|
340
|
-
A list of available package that exists in the snowflake anaconda channel
|
341
332
|
"""
|
342
333
|
if not self._is_fitted:
|
343
334
|
raise exceptions.SnowflakeMLException(
|
@@ -355,9 +346,7 @@ class NearestNeighbors(BaseTransformer):
|
|
355
346
|
"Session must not specified for snowpark dataset."
|
356
347
|
),
|
357
348
|
)
|
358
|
-
|
359
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
360
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
349
|
+
|
361
350
|
|
362
351
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
363
352
|
@telemetry.send_api_usage_telemetry(
|
@@ -403,7 +392,8 @@ class NearestNeighbors(BaseTransformer):
|
|
403
392
|
|
404
393
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
405
394
|
|
406
|
-
self.
|
395
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
396
|
+
self._deps = self._get_dependencies()
|
407
397
|
assert isinstance(
|
408
398
|
dataset._session, Session
|
409
399
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -486,10 +476,8 @@ class NearestNeighbors(BaseTransformer):
|
|
486
476
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
487
477
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
488
478
|
|
489
|
-
self.
|
490
|
-
|
491
|
-
inference_method=inference_method,
|
492
|
-
)
|
479
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
480
|
+
self._deps = self._get_dependencies()
|
493
481
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
494
482
|
|
495
483
|
transform_kwargs = dict(
|
@@ -556,16 +544,40 @@ class NearestNeighbors(BaseTransformer):
|
|
556
544
|
self._is_fitted = True
|
557
545
|
return output_result
|
558
546
|
|
547
|
+
|
548
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
549
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
550
|
+
""" Method not supported for this class.
|
559
551
|
|
560
|
-
|
561
|
-
|
562
|
-
|
552
|
+
|
553
|
+
Raises:
|
554
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
555
|
+
|
556
|
+
Args:
|
557
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
558
|
+
Snowpark or Pandas DataFrame.
|
559
|
+
output_cols_prefix: Prefix for the response columns
|
563
560
|
Returns:
|
564
561
|
Transformed dataset.
|
565
562
|
"""
|
566
|
-
self.
|
567
|
-
|
568
|
-
|
563
|
+
self._infer_input_output_cols(dataset)
|
564
|
+
super()._check_dataset_type(dataset)
|
565
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
566
|
+
estimator=self._sklearn_object,
|
567
|
+
dataset=dataset,
|
568
|
+
input_cols=self.input_cols,
|
569
|
+
label_cols=self.label_cols,
|
570
|
+
sample_weight_col=self.sample_weight_col,
|
571
|
+
autogenerated=self._autogenerated,
|
572
|
+
subproject=_SUBPROJECT,
|
573
|
+
)
|
574
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
575
|
+
drop_input_cols=self._drop_input_cols,
|
576
|
+
expected_output_cols_list=self.output_cols,
|
577
|
+
)
|
578
|
+
self._sklearn_object = fitted_estimator
|
579
|
+
self._is_fitted = True
|
580
|
+
return output_result
|
569
581
|
|
570
582
|
|
571
583
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -656,10 +668,8 @@ class NearestNeighbors(BaseTransformer):
|
|
656
668
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
657
669
|
|
658
670
|
if isinstance(dataset, DataFrame):
|
659
|
-
self.
|
660
|
-
|
661
|
-
inference_method=inference_method,
|
662
|
-
)
|
671
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
672
|
+
self._deps = self._get_dependencies()
|
663
673
|
assert isinstance(
|
664
674
|
dataset._session, Session
|
665
675
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -724,10 +734,8 @@ class NearestNeighbors(BaseTransformer):
|
|
724
734
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
725
735
|
|
726
736
|
if isinstance(dataset, DataFrame):
|
727
|
-
self.
|
728
|
-
|
729
|
-
inference_method=inference_method,
|
730
|
-
)
|
737
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
738
|
+
self._deps = self._get_dependencies()
|
731
739
|
assert isinstance(
|
732
740
|
dataset._session, Session
|
733
741
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -789,10 +797,8 @@ class NearestNeighbors(BaseTransformer):
|
|
789
797
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
790
798
|
|
791
799
|
if isinstance(dataset, DataFrame):
|
792
|
-
self.
|
793
|
-
|
794
|
-
inference_method=inference_method,
|
795
|
-
)
|
800
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
801
|
+
self._deps = self._get_dependencies()
|
796
802
|
assert isinstance(
|
797
803
|
dataset._session, Session
|
798
804
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -858,10 +864,8 @@ class NearestNeighbors(BaseTransformer):
|
|
858
864
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
859
865
|
|
860
866
|
if isinstance(dataset, DataFrame):
|
861
|
-
self.
|
862
|
-
|
863
|
-
inference_method=inference_method,
|
864
|
-
)
|
867
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
868
|
+
self._deps = self._get_dependencies()
|
865
869
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
866
870
|
transform_kwargs = dict(
|
867
871
|
session=dataset._session,
|
@@ -923,17 +927,15 @@ class NearestNeighbors(BaseTransformer):
|
|
923
927
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
924
928
|
|
925
929
|
if isinstance(dataset, DataFrame):
|
926
|
-
self.
|
927
|
-
|
928
|
-
inference_method="score",
|
929
|
-
)
|
930
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
931
|
+
self._deps = self._get_dependencies()
|
930
932
|
selected_cols = self._get_active_columns()
|
931
933
|
if len(selected_cols) > 0:
|
932
934
|
dataset = dataset.select(selected_cols)
|
933
935
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
934
936
|
transform_kwargs = dict(
|
935
937
|
session=dataset._session,
|
936
|
-
dependencies=
|
938
|
+
dependencies=self._deps,
|
937
939
|
score_sproc_imports=['sklearn'],
|
938
940
|
)
|
939
941
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1000,11 +1002,8 @@ class NearestNeighbors(BaseTransformer):
|
|
1000
1002
|
|
1001
1003
|
if isinstance(dataset, DataFrame):
|
1002
1004
|
|
1003
|
-
self.
|
1004
|
-
|
1005
|
-
inference_method=inference_method,
|
1006
|
-
|
1007
|
-
)
|
1005
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1006
|
+
self._deps = self._get_dependencies()
|
1008
1007
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1009
1008
|
transform_kwargs = dict(
|
1010
1009
|
session = dataset._session,
|