snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +72 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
- snowflake/ml/_internal/telemetry.py +1 -0
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/sql_identifier.py +14 -1
- snowflake/ml/dataset/__init__.py +11 -0
- snowflake/ml/dataset/dataset.py +455 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +199 -0
- snowflake/ml/feature_store/__init__.py +6 -0
- snowflake/ml/feature_store/access_manager.py +279 -0
- snowflake/ml/feature_store/feature_store.py +544 -358
- snowflake/ml/feature_store/feature_view.py +55 -16
- snowflake/ml/fileset/embedded_stage_fs.py +149 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +160 -0
- snowflake/ml/fileset/stage_fs.py +25 -10
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +65 -31
- snowflake/ml/model/_client/model/model_version_impl.py +159 -2
- snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
- snowflake/ml/model/_client/ops/model_ops.py +268 -83
- snowflake/ml/model/_client/sql/_base.py +34 -0
- snowflake/ml/model/_client/sql/model.py +42 -47
- snowflake/ml/model/_client/sql/model_version.py +164 -39
- snowflake/ml/model/_client/sql/stage.py +6 -32
- snowflake/ml/model/_client/sql/tag.py +32 -56
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +64 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +538 -36
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/_manager/model_manager.py +36 -7
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neighbors".replace("skle
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class KNeighborsClassifier(BaseTransformer):
|
70
64
|
r"""Classifier implementing the k-nearest neighbors vote
|
71
65
|
For more details on this class, see [sklearn.neighbors.KNeighborsClassifier]
|
@@ -334,20 +328,17 @@ class KNeighborsClassifier(BaseTransformer):
|
|
334
328
|
self,
|
335
329
|
dataset: DataFrame,
|
336
330
|
inference_method: str,
|
337
|
-
) ->
|
338
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
339
|
-
return the available package that exists in the snowflake anaconda channel
|
331
|
+
) -> None:
|
332
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
340
333
|
|
341
334
|
Args:
|
342
335
|
dataset: snowpark dataframe
|
343
336
|
inference_method: the inference method such as predict, score...
|
344
|
-
|
337
|
+
|
345
338
|
Raises:
|
346
339
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
347
340
|
SnowflakeMLException: If the session is None, raise error
|
348
341
|
|
349
|
-
Returns:
|
350
|
-
A list of available package that exists in the snowflake anaconda channel
|
351
342
|
"""
|
352
343
|
if not self._is_fitted:
|
353
344
|
raise exceptions.SnowflakeMLException(
|
@@ -365,9 +356,7 @@ class KNeighborsClassifier(BaseTransformer):
|
|
365
356
|
"Session must not specified for snowpark dataset."
|
366
357
|
),
|
367
358
|
)
|
368
|
-
|
369
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
370
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
359
|
+
|
371
360
|
|
372
361
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
373
362
|
@telemetry.send_api_usage_telemetry(
|
@@ -415,7 +404,8 @@ class KNeighborsClassifier(BaseTransformer):
|
|
415
404
|
|
416
405
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
417
406
|
|
418
|
-
self.
|
407
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
408
|
+
self._deps = self._get_dependencies()
|
419
409
|
assert isinstance(
|
420
410
|
dataset._session, Session
|
421
411
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -498,10 +488,8 @@ class KNeighborsClassifier(BaseTransformer):
|
|
498
488
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
499
489
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
500
490
|
|
501
|
-
self.
|
502
|
-
|
503
|
-
inference_method=inference_method,
|
504
|
-
)
|
491
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
492
|
+
self._deps = self._get_dependencies()
|
505
493
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
506
494
|
|
507
495
|
transform_kwargs = dict(
|
@@ -568,16 +556,40 @@ class KNeighborsClassifier(BaseTransformer):
|
|
568
556
|
self._is_fitted = True
|
569
557
|
return output_result
|
570
558
|
|
559
|
+
|
560
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
561
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
562
|
+
""" Method not supported for this class.
|
571
563
|
|
572
|
-
|
573
|
-
|
574
|
-
|
564
|
+
|
565
|
+
Raises:
|
566
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
567
|
+
|
568
|
+
Args:
|
569
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
570
|
+
Snowpark or Pandas DataFrame.
|
571
|
+
output_cols_prefix: Prefix for the response columns
|
575
572
|
Returns:
|
576
573
|
Transformed dataset.
|
577
574
|
"""
|
578
|
-
self.
|
579
|
-
|
580
|
-
|
575
|
+
self._infer_input_output_cols(dataset)
|
576
|
+
super()._check_dataset_type(dataset)
|
577
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
578
|
+
estimator=self._sklearn_object,
|
579
|
+
dataset=dataset,
|
580
|
+
input_cols=self.input_cols,
|
581
|
+
label_cols=self.label_cols,
|
582
|
+
sample_weight_col=self.sample_weight_col,
|
583
|
+
autogenerated=self._autogenerated,
|
584
|
+
subproject=_SUBPROJECT,
|
585
|
+
)
|
586
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
587
|
+
drop_input_cols=self._drop_input_cols,
|
588
|
+
expected_output_cols_list=self.output_cols,
|
589
|
+
)
|
590
|
+
self._sklearn_object = fitted_estimator
|
591
|
+
self._is_fitted = True
|
592
|
+
return output_result
|
581
593
|
|
582
594
|
|
583
595
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -670,10 +682,8 @@ class KNeighborsClassifier(BaseTransformer):
|
|
670
682
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
671
683
|
|
672
684
|
if isinstance(dataset, DataFrame):
|
673
|
-
self.
|
674
|
-
|
675
|
-
inference_method=inference_method,
|
676
|
-
)
|
685
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
686
|
+
self._deps = self._get_dependencies()
|
677
687
|
assert isinstance(
|
678
688
|
dataset._session, Session
|
679
689
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -740,10 +750,8 @@ class KNeighborsClassifier(BaseTransformer):
|
|
740
750
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
741
751
|
|
742
752
|
if isinstance(dataset, DataFrame):
|
743
|
-
self.
|
744
|
-
|
745
|
-
inference_method=inference_method,
|
746
|
-
)
|
753
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
754
|
+
self._deps = self._get_dependencies()
|
747
755
|
assert isinstance(
|
748
756
|
dataset._session, Session
|
749
757
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -805,10 +813,8 @@ class KNeighborsClassifier(BaseTransformer):
|
|
805
813
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
806
814
|
|
807
815
|
if isinstance(dataset, DataFrame):
|
808
|
-
self.
|
809
|
-
|
810
|
-
inference_method=inference_method,
|
811
|
-
)
|
816
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
817
|
+
self._deps = self._get_dependencies()
|
812
818
|
assert isinstance(
|
813
819
|
dataset._session, Session
|
814
820
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -874,10 +880,8 @@ class KNeighborsClassifier(BaseTransformer):
|
|
874
880
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
875
881
|
|
876
882
|
if isinstance(dataset, DataFrame):
|
877
|
-
self.
|
878
|
-
|
879
|
-
inference_method=inference_method,
|
880
|
-
)
|
883
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
884
|
+
self._deps = self._get_dependencies()
|
881
885
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
882
886
|
transform_kwargs = dict(
|
883
887
|
session=dataset._session,
|
@@ -941,17 +945,15 @@ class KNeighborsClassifier(BaseTransformer):
|
|
941
945
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
942
946
|
|
943
947
|
if isinstance(dataset, DataFrame):
|
944
|
-
self.
|
945
|
-
|
946
|
-
inference_method="score",
|
947
|
-
)
|
948
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
949
|
+
self._deps = self._get_dependencies()
|
948
950
|
selected_cols = self._get_active_columns()
|
949
951
|
if len(selected_cols) > 0:
|
950
952
|
dataset = dataset.select(selected_cols)
|
951
953
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
952
954
|
transform_kwargs = dict(
|
953
955
|
session=dataset._session,
|
954
|
-
dependencies=
|
956
|
+
dependencies=self._deps,
|
955
957
|
score_sproc_imports=['sklearn'],
|
956
958
|
)
|
957
959
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1018,11 +1020,8 @@ class KNeighborsClassifier(BaseTransformer):
|
|
1018
1020
|
|
1019
1021
|
if isinstance(dataset, DataFrame):
|
1020
1022
|
|
1021
|
-
self.
|
1022
|
-
|
1023
|
-
inference_method=inference_method,
|
1024
|
-
|
1025
|
-
)
|
1023
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1024
|
+
self._deps = self._get_dependencies()
|
1026
1025
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1027
1026
|
transform_kwargs = dict(
|
1028
1027
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neighbors".replace("skle
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class KNeighborsRegressor(BaseTransformer):
|
70
64
|
r"""Regression based on k-nearest neighbors
|
71
65
|
For more details on this class, see [sklearn.neighbors.KNeighborsRegressor]
|
@@ -336,20 +330,17 @@ class KNeighborsRegressor(BaseTransformer):
|
|
336
330
|
self,
|
337
331
|
dataset: DataFrame,
|
338
332
|
inference_method: str,
|
339
|
-
) ->
|
340
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
341
|
-
return the available package that exists in the snowflake anaconda channel
|
333
|
+
) -> None:
|
334
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
342
335
|
|
343
336
|
Args:
|
344
337
|
dataset: snowpark dataframe
|
345
338
|
inference_method: the inference method such as predict, score...
|
346
|
-
|
339
|
+
|
347
340
|
Raises:
|
348
341
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
349
342
|
SnowflakeMLException: If the session is None, raise error
|
350
343
|
|
351
|
-
Returns:
|
352
|
-
A list of available package that exists in the snowflake anaconda channel
|
353
344
|
"""
|
354
345
|
if not self._is_fitted:
|
355
346
|
raise exceptions.SnowflakeMLException(
|
@@ -367,9 +358,7 @@ class KNeighborsRegressor(BaseTransformer):
|
|
367
358
|
"Session must not specified for snowpark dataset."
|
368
359
|
),
|
369
360
|
)
|
370
|
-
|
371
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
372
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
361
|
+
|
373
362
|
|
374
363
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
375
364
|
@telemetry.send_api_usage_telemetry(
|
@@ -417,7 +406,8 @@ class KNeighborsRegressor(BaseTransformer):
|
|
417
406
|
|
418
407
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
419
408
|
|
420
|
-
self.
|
409
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
410
|
+
self._deps = self._get_dependencies()
|
421
411
|
assert isinstance(
|
422
412
|
dataset._session, Session
|
423
413
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -500,10 +490,8 @@ class KNeighborsRegressor(BaseTransformer):
|
|
500
490
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
501
491
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
502
492
|
|
503
|
-
self.
|
504
|
-
|
505
|
-
inference_method=inference_method,
|
506
|
-
)
|
493
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
494
|
+
self._deps = self._get_dependencies()
|
507
495
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
508
496
|
|
509
497
|
transform_kwargs = dict(
|
@@ -570,16 +558,40 @@ class KNeighborsRegressor(BaseTransformer):
|
|
570
558
|
self._is_fitted = True
|
571
559
|
return output_result
|
572
560
|
|
561
|
+
|
562
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
563
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
564
|
+
""" Method not supported for this class.
|
573
565
|
|
574
|
-
|
575
|
-
|
576
|
-
|
566
|
+
|
567
|
+
Raises:
|
568
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
569
|
+
|
570
|
+
Args:
|
571
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
572
|
+
Snowpark or Pandas DataFrame.
|
573
|
+
output_cols_prefix: Prefix for the response columns
|
577
574
|
Returns:
|
578
575
|
Transformed dataset.
|
579
576
|
"""
|
580
|
-
self.
|
581
|
-
|
582
|
-
|
577
|
+
self._infer_input_output_cols(dataset)
|
578
|
+
super()._check_dataset_type(dataset)
|
579
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
580
|
+
estimator=self._sklearn_object,
|
581
|
+
dataset=dataset,
|
582
|
+
input_cols=self.input_cols,
|
583
|
+
label_cols=self.label_cols,
|
584
|
+
sample_weight_col=self.sample_weight_col,
|
585
|
+
autogenerated=self._autogenerated,
|
586
|
+
subproject=_SUBPROJECT,
|
587
|
+
)
|
588
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
589
|
+
drop_input_cols=self._drop_input_cols,
|
590
|
+
expected_output_cols_list=self.output_cols,
|
591
|
+
)
|
592
|
+
self._sklearn_object = fitted_estimator
|
593
|
+
self._is_fitted = True
|
594
|
+
return output_result
|
583
595
|
|
584
596
|
|
585
597
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -670,10 +682,8 @@ class KNeighborsRegressor(BaseTransformer):
|
|
670
682
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
671
683
|
|
672
684
|
if isinstance(dataset, DataFrame):
|
673
|
-
self.
|
674
|
-
|
675
|
-
inference_method=inference_method,
|
676
|
-
)
|
685
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
686
|
+
self._deps = self._get_dependencies()
|
677
687
|
assert isinstance(
|
678
688
|
dataset._session, Session
|
679
689
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -738,10 +748,8 @@ class KNeighborsRegressor(BaseTransformer):
|
|
738
748
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
739
749
|
|
740
750
|
if isinstance(dataset, DataFrame):
|
741
|
-
self.
|
742
|
-
|
743
|
-
inference_method=inference_method,
|
744
|
-
)
|
751
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
752
|
+
self._deps = self._get_dependencies()
|
745
753
|
assert isinstance(
|
746
754
|
dataset._session, Session
|
747
755
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -803,10 +811,8 @@ class KNeighborsRegressor(BaseTransformer):
|
|
803
811
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
804
812
|
|
805
813
|
if isinstance(dataset, DataFrame):
|
806
|
-
self.
|
807
|
-
|
808
|
-
inference_method=inference_method,
|
809
|
-
)
|
814
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
815
|
+
self._deps = self._get_dependencies()
|
810
816
|
assert isinstance(
|
811
817
|
dataset._session, Session
|
812
818
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -872,10 +878,8 @@ class KNeighborsRegressor(BaseTransformer):
|
|
872
878
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
873
879
|
|
874
880
|
if isinstance(dataset, DataFrame):
|
875
|
-
self.
|
876
|
-
|
877
|
-
inference_method=inference_method,
|
878
|
-
)
|
881
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
882
|
+
self._deps = self._get_dependencies()
|
879
883
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
880
884
|
transform_kwargs = dict(
|
881
885
|
session=dataset._session,
|
@@ -939,17 +943,15 @@ class KNeighborsRegressor(BaseTransformer):
|
|
939
943
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
940
944
|
|
941
945
|
if isinstance(dataset, DataFrame):
|
942
|
-
self.
|
943
|
-
|
944
|
-
inference_method="score",
|
945
|
-
)
|
946
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
947
|
+
self._deps = self._get_dependencies()
|
946
948
|
selected_cols = self._get_active_columns()
|
947
949
|
if len(selected_cols) > 0:
|
948
950
|
dataset = dataset.select(selected_cols)
|
949
951
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
950
952
|
transform_kwargs = dict(
|
951
953
|
session=dataset._session,
|
952
|
-
dependencies=
|
954
|
+
dependencies=self._deps,
|
953
955
|
score_sproc_imports=['sklearn'],
|
954
956
|
)
|
955
957
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1016,11 +1018,8 @@ class KNeighborsRegressor(BaseTransformer):
|
|
1016
1018
|
|
1017
1019
|
if isinstance(dataset, DataFrame):
|
1018
1020
|
|
1019
|
-
self.
|
1020
|
-
|
1021
|
-
inference_method=inference_method,
|
1022
|
-
|
1023
|
-
)
|
1021
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1022
|
+
self._deps = self._get_dependencies()
|
1024
1023
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1025
1024
|
transform_kwargs = dict(
|
1026
1025
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neighbors".replace("skle
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class KernelDensity(BaseTransformer):
|
70
64
|
r"""Kernel Density Estimation
|
71
65
|
For more details on this class, see [sklearn.neighbors.KernelDensity]
|
@@ -313,20 +307,17 @@ class KernelDensity(BaseTransformer):
|
|
313
307
|
self,
|
314
308
|
dataset: DataFrame,
|
315
309
|
inference_method: str,
|
316
|
-
) ->
|
317
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
318
|
-
return the available package that exists in the snowflake anaconda channel
|
310
|
+
) -> None:
|
311
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
319
312
|
|
320
313
|
Args:
|
321
314
|
dataset: snowpark dataframe
|
322
315
|
inference_method: the inference method such as predict, score...
|
323
|
-
|
316
|
+
|
324
317
|
Raises:
|
325
318
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
326
319
|
SnowflakeMLException: If the session is None, raise error
|
327
320
|
|
328
|
-
Returns:
|
329
|
-
A list of available package that exists in the snowflake anaconda channel
|
330
321
|
"""
|
331
322
|
if not self._is_fitted:
|
332
323
|
raise exceptions.SnowflakeMLException(
|
@@ -344,9 +335,7 @@ class KernelDensity(BaseTransformer):
|
|
344
335
|
"Session must not specified for snowpark dataset."
|
345
336
|
),
|
346
337
|
)
|
347
|
-
|
348
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
349
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
338
|
+
|
350
339
|
|
351
340
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
352
341
|
@telemetry.send_api_usage_telemetry(
|
@@ -392,7 +381,8 @@ class KernelDensity(BaseTransformer):
|
|
392
381
|
|
393
382
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
394
383
|
|
395
|
-
self.
|
384
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
385
|
+
self._deps = self._get_dependencies()
|
396
386
|
assert isinstance(
|
397
387
|
dataset._session, Session
|
398
388
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -475,10 +465,8 @@ class KernelDensity(BaseTransformer):
|
|
475
465
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
476
466
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
477
467
|
|
478
|
-
self.
|
479
|
-
|
480
|
-
inference_method=inference_method,
|
481
|
-
)
|
468
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
469
|
+
self._deps = self._get_dependencies()
|
482
470
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
483
471
|
|
484
472
|
transform_kwargs = dict(
|
@@ -545,16 +533,40 @@ class KernelDensity(BaseTransformer):
|
|
545
533
|
self._is_fitted = True
|
546
534
|
return output_result
|
547
535
|
|
536
|
+
|
537
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
538
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
539
|
+
""" Method not supported for this class.
|
548
540
|
|
549
|
-
|
550
|
-
|
551
|
-
|
541
|
+
|
542
|
+
Raises:
|
543
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
544
|
+
|
545
|
+
Args:
|
546
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
547
|
+
Snowpark or Pandas DataFrame.
|
548
|
+
output_cols_prefix: Prefix for the response columns
|
552
549
|
Returns:
|
553
550
|
Transformed dataset.
|
554
551
|
"""
|
555
|
-
self.
|
556
|
-
|
557
|
-
|
552
|
+
self._infer_input_output_cols(dataset)
|
553
|
+
super()._check_dataset_type(dataset)
|
554
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
555
|
+
estimator=self._sklearn_object,
|
556
|
+
dataset=dataset,
|
557
|
+
input_cols=self.input_cols,
|
558
|
+
label_cols=self.label_cols,
|
559
|
+
sample_weight_col=self.sample_weight_col,
|
560
|
+
autogenerated=self._autogenerated,
|
561
|
+
subproject=_SUBPROJECT,
|
562
|
+
)
|
563
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
564
|
+
drop_input_cols=self._drop_input_cols,
|
565
|
+
expected_output_cols_list=self.output_cols,
|
566
|
+
)
|
567
|
+
self._sklearn_object = fitted_estimator
|
568
|
+
self._is_fitted = True
|
569
|
+
return output_result
|
558
570
|
|
559
571
|
|
560
572
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -645,10 +657,8 @@ class KernelDensity(BaseTransformer):
|
|
645
657
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
646
658
|
|
647
659
|
if isinstance(dataset, DataFrame):
|
648
|
-
self.
|
649
|
-
|
650
|
-
inference_method=inference_method,
|
651
|
-
)
|
660
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
661
|
+
self._deps = self._get_dependencies()
|
652
662
|
assert isinstance(
|
653
663
|
dataset._session, Session
|
654
664
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -713,10 +723,8 @@ class KernelDensity(BaseTransformer):
|
|
713
723
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
714
724
|
|
715
725
|
if isinstance(dataset, DataFrame):
|
716
|
-
self.
|
717
|
-
|
718
|
-
inference_method=inference_method,
|
719
|
-
)
|
726
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
727
|
+
self._deps = self._get_dependencies()
|
720
728
|
assert isinstance(
|
721
729
|
dataset._session, Session
|
722
730
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -778,10 +786,8 @@ class KernelDensity(BaseTransformer):
|
|
778
786
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
779
787
|
|
780
788
|
if isinstance(dataset, DataFrame):
|
781
|
-
self.
|
782
|
-
|
783
|
-
inference_method=inference_method,
|
784
|
-
)
|
789
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
790
|
+
self._deps = self._get_dependencies()
|
785
791
|
assert isinstance(
|
786
792
|
dataset._session, Session
|
787
793
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -849,10 +855,8 @@ class KernelDensity(BaseTransformer):
|
|
849
855
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
850
856
|
|
851
857
|
if isinstance(dataset, DataFrame):
|
852
|
-
self.
|
853
|
-
|
854
|
-
inference_method=inference_method,
|
855
|
-
)
|
858
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
859
|
+
self._deps = self._get_dependencies()
|
856
860
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
857
861
|
transform_kwargs = dict(
|
858
862
|
session=dataset._session,
|
@@ -916,17 +920,15 @@ class KernelDensity(BaseTransformer):
|
|
916
920
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
917
921
|
|
918
922
|
if isinstance(dataset, DataFrame):
|
919
|
-
self.
|
920
|
-
|
921
|
-
inference_method="score",
|
922
|
-
)
|
923
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
924
|
+
self._deps = self._get_dependencies()
|
923
925
|
selected_cols = self._get_active_columns()
|
924
926
|
if len(selected_cols) > 0:
|
925
927
|
dataset = dataset.select(selected_cols)
|
926
928
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
927
929
|
transform_kwargs = dict(
|
928
930
|
session=dataset._session,
|
929
|
-
dependencies=
|
931
|
+
dependencies=self._deps,
|
930
932
|
score_sproc_imports=['sklearn'],
|
931
933
|
)
|
932
934
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -991,11 +993,8 @@ class KernelDensity(BaseTransformer):
|
|
991
993
|
|
992
994
|
if isinstance(dataset, DataFrame):
|
993
995
|
|
994
|
-
self.
|
995
|
-
|
996
|
-
inference_method=inference_method,
|
997
|
-
|
998
|
-
)
|
996
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
997
|
+
self._deps = self._get_dependencies()
|
999
998
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1000
999
|
transform_kwargs = dict(
|
1001
1000
|
session = dataset._session,
|