snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neighbors".replace("skle
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class NeighborhoodComponentsAnalysis(BaseTransformer):
|
70
64
|
r"""Neighborhood Components Analysis
|
71
65
|
For more details on this class, see [sklearn.neighbors.NeighborhoodComponentsAnalysis]
|
@@ -345,20 +339,17 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
345
339
|
self,
|
346
340
|
dataset: DataFrame,
|
347
341
|
inference_method: str,
|
348
|
-
) ->
|
349
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
350
|
-
return the available package that exists in the snowflake anaconda channel
|
342
|
+
) -> None:
|
343
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
351
344
|
|
352
345
|
Args:
|
353
346
|
dataset: snowpark dataframe
|
354
347
|
inference_method: the inference method such as predict, score...
|
355
|
-
|
348
|
+
|
356
349
|
Raises:
|
357
350
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
358
351
|
SnowflakeMLException: If the session is None, raise error
|
359
352
|
|
360
|
-
Returns:
|
361
|
-
A list of available package that exists in the snowflake anaconda channel
|
362
353
|
"""
|
363
354
|
if not self._is_fitted:
|
364
355
|
raise exceptions.SnowflakeMLException(
|
@@ -376,9 +367,7 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
376
367
|
"Session must not specified for snowpark dataset."
|
377
368
|
),
|
378
369
|
)
|
379
|
-
|
380
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
381
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
370
|
+
|
382
371
|
|
383
372
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
384
373
|
@telemetry.send_api_usage_telemetry(
|
@@ -424,7 +413,8 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
424
413
|
|
425
414
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
426
415
|
|
427
|
-
self.
|
416
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
417
|
+
self._deps = self._get_dependencies()
|
428
418
|
assert isinstance(
|
429
419
|
dataset._session, Session
|
430
420
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -509,10 +499,8 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
509
499
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
510
500
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
511
501
|
|
512
|
-
self.
|
513
|
-
|
514
|
-
inference_method=inference_method,
|
515
|
-
)
|
502
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
503
|
+
self._deps = self._get_dependencies()
|
516
504
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
517
505
|
|
518
506
|
transform_kwargs = dict(
|
@@ -579,16 +567,42 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
579
567
|
self._is_fitted = True
|
580
568
|
return output_result
|
581
569
|
|
570
|
+
|
571
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
572
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
573
|
+
""" Fit to data, then transform it
|
574
|
+
For more details on this function, see [sklearn.neighbors.NeighborhoodComponentsAnalysis.fit_transform]
|
575
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.NeighborhoodComponentsAnalysis.html#sklearn.neighbors.NeighborhoodComponentsAnalysis.fit_transform)
|
576
|
+
|
582
577
|
|
583
|
-
|
584
|
-
|
585
|
-
|
578
|
+
Raises:
|
579
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
580
|
+
|
581
|
+
Args:
|
582
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
583
|
+
Snowpark or Pandas DataFrame.
|
584
|
+
output_cols_prefix: Prefix for the response columns
|
586
585
|
Returns:
|
587
586
|
Transformed dataset.
|
588
587
|
"""
|
589
|
-
self.
|
590
|
-
|
591
|
-
|
588
|
+
self._infer_input_output_cols(dataset)
|
589
|
+
super()._check_dataset_type(dataset)
|
590
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
591
|
+
estimator=self._sklearn_object,
|
592
|
+
dataset=dataset,
|
593
|
+
input_cols=self.input_cols,
|
594
|
+
label_cols=self.label_cols,
|
595
|
+
sample_weight_col=self.sample_weight_col,
|
596
|
+
autogenerated=self._autogenerated,
|
597
|
+
subproject=_SUBPROJECT,
|
598
|
+
)
|
599
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
600
|
+
drop_input_cols=self._drop_input_cols,
|
601
|
+
expected_output_cols_list=self.output_cols,
|
602
|
+
)
|
603
|
+
self._sklearn_object = fitted_estimator
|
604
|
+
self._is_fitted = True
|
605
|
+
return output_result
|
592
606
|
|
593
607
|
|
594
608
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -679,10 +693,8 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
679
693
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
680
694
|
|
681
695
|
if isinstance(dataset, DataFrame):
|
682
|
-
self.
|
683
|
-
|
684
|
-
inference_method=inference_method,
|
685
|
-
)
|
696
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
697
|
+
self._deps = self._get_dependencies()
|
686
698
|
assert isinstance(
|
687
699
|
dataset._session, Session
|
688
700
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -747,10 +759,8 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
747
759
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
748
760
|
|
749
761
|
if isinstance(dataset, DataFrame):
|
750
|
-
self.
|
751
|
-
|
752
|
-
inference_method=inference_method,
|
753
|
-
)
|
762
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
763
|
+
self._deps = self._get_dependencies()
|
754
764
|
assert isinstance(
|
755
765
|
dataset._session, Session
|
756
766
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -812,10 +822,8 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
812
822
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
813
823
|
|
814
824
|
if isinstance(dataset, DataFrame):
|
815
|
-
self.
|
816
|
-
|
817
|
-
inference_method=inference_method,
|
818
|
-
)
|
825
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
826
|
+
self._deps = self._get_dependencies()
|
819
827
|
assert isinstance(
|
820
828
|
dataset._session, Session
|
821
829
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -881,10 +889,8 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
881
889
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
882
890
|
|
883
891
|
if isinstance(dataset, DataFrame):
|
884
|
-
self.
|
885
|
-
|
886
|
-
inference_method=inference_method,
|
887
|
-
)
|
892
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
893
|
+
self._deps = self._get_dependencies()
|
888
894
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
889
895
|
transform_kwargs = dict(
|
890
896
|
session=dataset._session,
|
@@ -946,17 +952,15 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
946
952
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
947
953
|
|
948
954
|
if isinstance(dataset, DataFrame):
|
949
|
-
self.
|
950
|
-
|
951
|
-
inference_method="score",
|
952
|
-
)
|
955
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
956
|
+
self._deps = self._get_dependencies()
|
953
957
|
selected_cols = self._get_active_columns()
|
954
958
|
if len(selected_cols) > 0:
|
955
959
|
dataset = dataset.select(selected_cols)
|
956
960
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
957
961
|
transform_kwargs = dict(
|
958
962
|
session=dataset._session,
|
959
|
-
dependencies=
|
963
|
+
dependencies=self._deps,
|
960
964
|
score_sproc_imports=['sklearn'],
|
961
965
|
)
|
962
966
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1021,11 +1025,8 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
1021
1025
|
|
1022
1026
|
if isinstance(dataset, DataFrame):
|
1023
1027
|
|
1024
|
-
self.
|
1025
|
-
|
1026
|
-
inference_method=inference_method,
|
1027
|
-
|
1028
|
-
)
|
1028
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1029
|
+
self._deps = self._get_dependencies()
|
1029
1030
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1030
1031
|
transform_kwargs = dict(
|
1031
1032
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neighbors".replace("skle
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class RadiusNeighborsClassifier(BaseTransformer):
|
70
64
|
r"""Classifier implementing a vote among neighbors within a given radius
|
71
65
|
For more details on this class, see [sklearn.neighbors.RadiusNeighborsClassifier]
|
@@ -346,20 +340,17 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
346
340
|
self,
|
347
341
|
dataset: DataFrame,
|
348
342
|
inference_method: str,
|
349
|
-
) ->
|
350
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
351
|
-
return the available package that exists in the snowflake anaconda channel
|
343
|
+
) -> None:
|
344
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
352
345
|
|
353
346
|
Args:
|
354
347
|
dataset: snowpark dataframe
|
355
348
|
inference_method: the inference method such as predict, score...
|
356
|
-
|
349
|
+
|
357
350
|
Raises:
|
358
351
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
359
352
|
SnowflakeMLException: If the session is None, raise error
|
360
353
|
|
361
|
-
Returns:
|
362
|
-
A list of available package that exists in the snowflake anaconda channel
|
363
354
|
"""
|
364
355
|
if not self._is_fitted:
|
365
356
|
raise exceptions.SnowflakeMLException(
|
@@ -377,9 +368,7 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
377
368
|
"Session must not specified for snowpark dataset."
|
378
369
|
),
|
379
370
|
)
|
380
|
-
|
381
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
382
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
371
|
+
|
383
372
|
|
384
373
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
385
374
|
@telemetry.send_api_usage_telemetry(
|
@@ -427,7 +416,8 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
427
416
|
|
428
417
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
429
418
|
|
430
|
-
self.
|
419
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
420
|
+
self._deps = self._get_dependencies()
|
431
421
|
assert isinstance(
|
432
422
|
dataset._session, Session
|
433
423
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -510,10 +500,8 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
510
500
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
511
501
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
512
502
|
|
513
|
-
self.
|
514
|
-
|
515
|
-
inference_method=inference_method,
|
516
|
-
)
|
503
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
504
|
+
self._deps = self._get_dependencies()
|
517
505
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
518
506
|
|
519
507
|
transform_kwargs = dict(
|
@@ -580,16 +568,40 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
580
568
|
self._is_fitted = True
|
581
569
|
return output_result
|
582
570
|
|
571
|
+
|
572
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
573
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
574
|
+
""" Method not supported for this class.
|
583
575
|
|
584
|
-
|
585
|
-
|
586
|
-
|
576
|
+
|
577
|
+
Raises:
|
578
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
579
|
+
|
580
|
+
Args:
|
581
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
582
|
+
Snowpark or Pandas DataFrame.
|
583
|
+
output_cols_prefix: Prefix for the response columns
|
587
584
|
Returns:
|
588
585
|
Transformed dataset.
|
589
586
|
"""
|
590
|
-
self.
|
591
|
-
|
592
|
-
|
587
|
+
self._infer_input_output_cols(dataset)
|
588
|
+
super()._check_dataset_type(dataset)
|
589
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
590
|
+
estimator=self._sklearn_object,
|
591
|
+
dataset=dataset,
|
592
|
+
input_cols=self.input_cols,
|
593
|
+
label_cols=self.label_cols,
|
594
|
+
sample_weight_col=self.sample_weight_col,
|
595
|
+
autogenerated=self._autogenerated,
|
596
|
+
subproject=_SUBPROJECT,
|
597
|
+
)
|
598
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
599
|
+
drop_input_cols=self._drop_input_cols,
|
600
|
+
expected_output_cols_list=self.output_cols,
|
601
|
+
)
|
602
|
+
self._sklearn_object = fitted_estimator
|
603
|
+
self._is_fitted = True
|
604
|
+
return output_result
|
593
605
|
|
594
606
|
|
595
607
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -682,10 +694,8 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
682
694
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
683
695
|
|
684
696
|
if isinstance(dataset, DataFrame):
|
685
|
-
self.
|
686
|
-
|
687
|
-
inference_method=inference_method,
|
688
|
-
)
|
697
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
698
|
+
self._deps = self._get_dependencies()
|
689
699
|
assert isinstance(
|
690
700
|
dataset._session, Session
|
691
701
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -752,10 +762,8 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
752
762
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
753
763
|
|
754
764
|
if isinstance(dataset, DataFrame):
|
755
|
-
self.
|
756
|
-
|
757
|
-
inference_method=inference_method,
|
758
|
-
)
|
765
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
766
|
+
self._deps = self._get_dependencies()
|
759
767
|
assert isinstance(
|
760
768
|
dataset._session, Session
|
761
769
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -817,10 +825,8 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
817
825
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
818
826
|
|
819
827
|
if isinstance(dataset, DataFrame):
|
820
|
-
self.
|
821
|
-
|
822
|
-
inference_method=inference_method,
|
823
|
-
)
|
828
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
829
|
+
self._deps = self._get_dependencies()
|
824
830
|
assert isinstance(
|
825
831
|
dataset._session, Session
|
826
832
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -886,10 +892,8 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
886
892
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
887
893
|
|
888
894
|
if isinstance(dataset, DataFrame):
|
889
|
-
self.
|
890
|
-
|
891
|
-
inference_method=inference_method,
|
892
|
-
)
|
895
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
896
|
+
self._deps = self._get_dependencies()
|
893
897
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
894
898
|
transform_kwargs = dict(
|
895
899
|
session=dataset._session,
|
@@ -953,17 +957,15 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
953
957
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
954
958
|
|
955
959
|
if isinstance(dataset, DataFrame):
|
956
|
-
self.
|
957
|
-
|
958
|
-
inference_method="score",
|
959
|
-
)
|
960
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
961
|
+
self._deps = self._get_dependencies()
|
960
962
|
selected_cols = self._get_active_columns()
|
961
963
|
if len(selected_cols) > 0:
|
962
964
|
dataset = dataset.select(selected_cols)
|
963
965
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
964
966
|
transform_kwargs = dict(
|
965
967
|
session=dataset._session,
|
966
|
-
dependencies=
|
968
|
+
dependencies=self._deps,
|
967
969
|
score_sproc_imports=['sklearn'],
|
968
970
|
)
|
969
971
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1028,11 +1030,8 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
1028
1030
|
|
1029
1031
|
if isinstance(dataset, DataFrame):
|
1030
1032
|
|
1031
|
-
self.
|
1032
|
-
|
1033
|
-
inference_method=inference_method,
|
1034
|
-
|
1035
|
-
)
|
1033
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1034
|
+
self._deps = self._get_dependencies()
|
1036
1035
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1037
1036
|
transform_kwargs = dict(
|
1038
1037
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neighbors".replace("skle
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class RadiusNeighborsRegressor(BaseTransformer):
|
70
64
|
r"""Regression based on neighbors within a fixed radius
|
71
65
|
For more details on this class, see [sklearn.neighbors.RadiusNeighborsRegressor]
|
@@ -336,20 +330,17 @@ class RadiusNeighborsRegressor(BaseTransformer):
|
|
336
330
|
self,
|
337
331
|
dataset: DataFrame,
|
338
332
|
inference_method: str,
|
339
|
-
) ->
|
340
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
341
|
-
return the available package that exists in the snowflake anaconda channel
|
333
|
+
) -> None:
|
334
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
342
335
|
|
343
336
|
Args:
|
344
337
|
dataset: snowpark dataframe
|
345
338
|
inference_method: the inference method such as predict, score...
|
346
|
-
|
339
|
+
|
347
340
|
Raises:
|
348
341
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
349
342
|
SnowflakeMLException: If the session is None, raise error
|
350
343
|
|
351
|
-
Returns:
|
352
|
-
A list of available package that exists in the snowflake anaconda channel
|
353
344
|
"""
|
354
345
|
if not self._is_fitted:
|
355
346
|
raise exceptions.SnowflakeMLException(
|
@@ -367,9 +358,7 @@ class RadiusNeighborsRegressor(BaseTransformer):
|
|
367
358
|
"Session must not specified for snowpark dataset."
|
368
359
|
),
|
369
360
|
)
|
370
|
-
|
371
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
372
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
361
|
+
|
373
362
|
|
374
363
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
375
364
|
@telemetry.send_api_usage_telemetry(
|
@@ -417,7 +406,8 @@ class RadiusNeighborsRegressor(BaseTransformer):
|
|
417
406
|
|
418
407
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
419
408
|
|
420
|
-
self.
|
409
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
410
|
+
self._deps = self._get_dependencies()
|
421
411
|
assert isinstance(
|
422
412
|
dataset._session, Session
|
423
413
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -500,10 +490,8 @@ class RadiusNeighborsRegressor(BaseTransformer):
|
|
500
490
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
501
491
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
502
492
|
|
503
|
-
self.
|
504
|
-
|
505
|
-
inference_method=inference_method,
|
506
|
-
)
|
493
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
494
|
+
self._deps = self._get_dependencies()
|
507
495
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
508
496
|
|
509
497
|
transform_kwargs = dict(
|
@@ -570,16 +558,40 @@ class RadiusNeighborsRegressor(BaseTransformer):
|
|
570
558
|
self._is_fitted = True
|
571
559
|
return output_result
|
572
560
|
|
561
|
+
|
562
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
563
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
564
|
+
""" Method not supported for this class.
|
573
565
|
|
574
|
-
|
575
|
-
|
576
|
-
|
566
|
+
|
567
|
+
Raises:
|
568
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
569
|
+
|
570
|
+
Args:
|
571
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
572
|
+
Snowpark or Pandas DataFrame.
|
573
|
+
output_cols_prefix: Prefix for the response columns
|
577
574
|
Returns:
|
578
575
|
Transformed dataset.
|
579
576
|
"""
|
580
|
-
self.
|
581
|
-
|
582
|
-
|
577
|
+
self._infer_input_output_cols(dataset)
|
578
|
+
super()._check_dataset_type(dataset)
|
579
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
580
|
+
estimator=self._sklearn_object,
|
581
|
+
dataset=dataset,
|
582
|
+
input_cols=self.input_cols,
|
583
|
+
label_cols=self.label_cols,
|
584
|
+
sample_weight_col=self.sample_weight_col,
|
585
|
+
autogenerated=self._autogenerated,
|
586
|
+
subproject=_SUBPROJECT,
|
587
|
+
)
|
588
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
589
|
+
drop_input_cols=self._drop_input_cols,
|
590
|
+
expected_output_cols_list=self.output_cols,
|
591
|
+
)
|
592
|
+
self._sklearn_object = fitted_estimator
|
593
|
+
self._is_fitted = True
|
594
|
+
return output_result
|
583
595
|
|
584
596
|
|
585
597
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -670,10 +682,8 @@ class RadiusNeighborsRegressor(BaseTransformer):
|
|
670
682
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
671
683
|
|
672
684
|
if isinstance(dataset, DataFrame):
|
673
|
-
self.
|
674
|
-
|
675
|
-
inference_method=inference_method,
|
676
|
-
)
|
685
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
686
|
+
self._deps = self._get_dependencies()
|
677
687
|
assert isinstance(
|
678
688
|
dataset._session, Session
|
679
689
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -738,10 +748,8 @@ class RadiusNeighborsRegressor(BaseTransformer):
|
|
738
748
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
739
749
|
|
740
750
|
if isinstance(dataset, DataFrame):
|
741
|
-
self.
|
742
|
-
|
743
|
-
inference_method=inference_method,
|
744
|
-
)
|
751
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
752
|
+
self._deps = self._get_dependencies()
|
745
753
|
assert isinstance(
|
746
754
|
dataset._session, Session
|
747
755
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -803,10 +811,8 @@ class RadiusNeighborsRegressor(BaseTransformer):
|
|
803
811
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
804
812
|
|
805
813
|
if isinstance(dataset, DataFrame):
|
806
|
-
self.
|
807
|
-
|
808
|
-
inference_method=inference_method,
|
809
|
-
)
|
814
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
815
|
+
self._deps = self._get_dependencies()
|
810
816
|
assert isinstance(
|
811
817
|
dataset._session, Session
|
812
818
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -872,10 +878,8 @@ class RadiusNeighborsRegressor(BaseTransformer):
|
|
872
878
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
873
879
|
|
874
880
|
if isinstance(dataset, DataFrame):
|
875
|
-
self.
|
876
|
-
|
877
|
-
inference_method=inference_method,
|
878
|
-
)
|
881
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
882
|
+
self._deps = self._get_dependencies()
|
879
883
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
880
884
|
transform_kwargs = dict(
|
881
885
|
session=dataset._session,
|
@@ -939,17 +943,15 @@ class RadiusNeighborsRegressor(BaseTransformer):
|
|
939
943
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
940
944
|
|
941
945
|
if isinstance(dataset, DataFrame):
|
942
|
-
self.
|
943
|
-
|
944
|
-
inference_method="score",
|
945
|
-
)
|
946
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
947
|
+
self._deps = self._get_dependencies()
|
946
948
|
selected_cols = self._get_active_columns()
|
947
949
|
if len(selected_cols) > 0:
|
948
950
|
dataset = dataset.select(selected_cols)
|
949
951
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
950
952
|
transform_kwargs = dict(
|
951
953
|
session=dataset._session,
|
952
|
-
dependencies=
|
954
|
+
dependencies=self._deps,
|
953
955
|
score_sproc_imports=['sklearn'],
|
954
956
|
)
|
955
957
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1014,11 +1016,8 @@ class RadiusNeighborsRegressor(BaseTransformer):
|
|
1014
1016
|
|
1015
1017
|
if isinstance(dataset, DataFrame):
|
1016
1018
|
|
1017
|
-
self.
|
1018
|
-
|
1019
|
-
inference_method=inference_method,
|
1020
|
-
|
1021
|
-
)
|
1019
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1020
|
+
self._deps = self._get_dependencies()
|
1022
1021
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1023
1022
|
transform_kwargs = dict(
|
1024
1023
|
session = dataset._session,
|