snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.kernel_approximation".re
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class Nystroem(BaseTransformer):
|
70
64
|
r"""Approximate a kernel map using a subset of the training data
|
71
65
|
For more details on this class, see [sklearn.kernel_approximation.Nystroem]
|
@@ -308,20 +302,17 @@ class Nystroem(BaseTransformer):
|
|
308
302
|
self,
|
309
303
|
dataset: DataFrame,
|
310
304
|
inference_method: str,
|
311
|
-
) ->
|
312
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
313
|
-
return the available package that exists in the snowflake anaconda channel
|
305
|
+
) -> None:
|
306
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
314
307
|
|
315
308
|
Args:
|
316
309
|
dataset: snowpark dataframe
|
317
310
|
inference_method: the inference method such as predict, score...
|
318
|
-
|
311
|
+
|
319
312
|
Raises:
|
320
313
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
321
314
|
SnowflakeMLException: If the session is None, raise error
|
322
315
|
|
323
|
-
Returns:
|
324
|
-
A list of available package that exists in the snowflake anaconda channel
|
325
316
|
"""
|
326
317
|
if not self._is_fitted:
|
327
318
|
raise exceptions.SnowflakeMLException(
|
@@ -339,9 +330,7 @@ class Nystroem(BaseTransformer):
|
|
339
330
|
"Session must not specified for snowpark dataset."
|
340
331
|
),
|
341
332
|
)
|
342
|
-
|
343
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
344
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
333
|
+
|
345
334
|
|
346
335
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
347
336
|
@telemetry.send_api_usage_telemetry(
|
@@ -387,7 +376,8 @@ class Nystroem(BaseTransformer):
|
|
387
376
|
|
388
377
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
389
378
|
|
390
|
-
self.
|
379
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
380
|
+
self._deps = self._get_dependencies()
|
391
381
|
assert isinstance(
|
392
382
|
dataset._session, Session
|
393
383
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -472,10 +462,8 @@ class Nystroem(BaseTransformer):
|
|
472
462
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
473
463
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
474
464
|
|
475
|
-
self.
|
476
|
-
|
477
|
-
inference_method=inference_method,
|
478
|
-
)
|
465
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
466
|
+
self._deps = self._get_dependencies()
|
479
467
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
480
468
|
|
481
469
|
transform_kwargs = dict(
|
@@ -542,16 +530,42 @@ class Nystroem(BaseTransformer):
|
|
542
530
|
self._is_fitted = True
|
543
531
|
return output_result
|
544
532
|
|
533
|
+
|
534
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
535
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
536
|
+
""" Fit to data, then transform it
|
537
|
+
For more details on this function, see [sklearn.kernel_approximation.Nystroem.fit_transform]
|
538
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.kernel_approximation.Nystroem.html#sklearn.kernel_approximation.Nystroem.fit_transform)
|
539
|
+
|
545
540
|
|
546
|
-
|
547
|
-
|
548
|
-
|
541
|
+
Raises:
|
542
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
543
|
+
|
544
|
+
Args:
|
545
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
546
|
+
Snowpark or Pandas DataFrame.
|
547
|
+
output_cols_prefix: Prefix for the response columns
|
549
548
|
Returns:
|
550
549
|
Transformed dataset.
|
551
550
|
"""
|
552
|
-
self.
|
553
|
-
|
554
|
-
|
551
|
+
self._infer_input_output_cols(dataset)
|
552
|
+
super()._check_dataset_type(dataset)
|
553
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
554
|
+
estimator=self._sklearn_object,
|
555
|
+
dataset=dataset,
|
556
|
+
input_cols=self.input_cols,
|
557
|
+
label_cols=self.label_cols,
|
558
|
+
sample_weight_col=self.sample_weight_col,
|
559
|
+
autogenerated=self._autogenerated,
|
560
|
+
subproject=_SUBPROJECT,
|
561
|
+
)
|
562
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
563
|
+
drop_input_cols=self._drop_input_cols,
|
564
|
+
expected_output_cols_list=self.output_cols,
|
565
|
+
)
|
566
|
+
self._sklearn_object = fitted_estimator
|
567
|
+
self._is_fitted = True
|
568
|
+
return output_result
|
555
569
|
|
556
570
|
|
557
571
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -642,10 +656,8 @@ class Nystroem(BaseTransformer):
|
|
642
656
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
643
657
|
|
644
658
|
if isinstance(dataset, DataFrame):
|
645
|
-
self.
|
646
|
-
|
647
|
-
inference_method=inference_method,
|
648
|
-
)
|
659
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
660
|
+
self._deps = self._get_dependencies()
|
649
661
|
assert isinstance(
|
650
662
|
dataset._session, Session
|
651
663
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -710,10 +722,8 @@ class Nystroem(BaseTransformer):
|
|
710
722
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
711
723
|
|
712
724
|
if isinstance(dataset, DataFrame):
|
713
|
-
self.
|
714
|
-
|
715
|
-
inference_method=inference_method,
|
716
|
-
)
|
725
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
726
|
+
self._deps = self._get_dependencies()
|
717
727
|
assert isinstance(
|
718
728
|
dataset._session, Session
|
719
729
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -775,10 +785,8 @@ class Nystroem(BaseTransformer):
|
|
775
785
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
776
786
|
|
777
787
|
if isinstance(dataset, DataFrame):
|
778
|
-
self.
|
779
|
-
|
780
|
-
inference_method=inference_method,
|
781
|
-
)
|
788
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
789
|
+
self._deps = self._get_dependencies()
|
782
790
|
assert isinstance(
|
783
791
|
dataset._session, Session
|
784
792
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -844,10 +852,8 @@ class Nystroem(BaseTransformer):
|
|
844
852
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
845
853
|
|
846
854
|
if isinstance(dataset, DataFrame):
|
847
|
-
self.
|
848
|
-
|
849
|
-
inference_method=inference_method,
|
850
|
-
)
|
855
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
856
|
+
self._deps = self._get_dependencies()
|
851
857
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
852
858
|
transform_kwargs = dict(
|
853
859
|
session=dataset._session,
|
@@ -909,17 +915,15 @@ class Nystroem(BaseTransformer):
|
|
909
915
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
910
916
|
|
911
917
|
if isinstance(dataset, DataFrame):
|
912
|
-
self.
|
913
|
-
|
914
|
-
inference_method="score",
|
915
|
-
)
|
918
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
919
|
+
self._deps = self._get_dependencies()
|
916
920
|
selected_cols = self._get_active_columns()
|
917
921
|
if len(selected_cols) > 0:
|
918
922
|
dataset = dataset.select(selected_cols)
|
919
923
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
920
924
|
transform_kwargs = dict(
|
921
925
|
session=dataset._session,
|
922
|
-
dependencies=
|
926
|
+
dependencies=self._deps,
|
923
927
|
score_sproc_imports=['sklearn'],
|
924
928
|
)
|
925
929
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -984,11 +988,8 @@ class Nystroem(BaseTransformer):
|
|
984
988
|
|
985
989
|
if isinstance(dataset, DataFrame):
|
986
990
|
|
987
|
-
self.
|
988
|
-
|
989
|
-
inference_method=inference_method,
|
990
|
-
|
991
|
-
)
|
991
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
992
|
+
self._deps = self._get_dependencies()
|
992
993
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
993
994
|
transform_kwargs = dict(
|
994
995
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.kernel_approximation".re
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class PolynomialCountSketch(BaseTransformer):
|
70
64
|
r"""Polynomial kernel approximation via Tensor Sketch
|
71
65
|
For more details on this class, see [sklearn.kernel_approximation.PolynomialCountSketch]
|
@@ -284,20 +278,17 @@ class PolynomialCountSketch(BaseTransformer):
|
|
284
278
|
self,
|
285
279
|
dataset: DataFrame,
|
286
280
|
inference_method: str,
|
287
|
-
) ->
|
288
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
289
|
-
return the available package that exists in the snowflake anaconda channel
|
281
|
+
) -> None:
|
282
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
290
283
|
|
291
284
|
Args:
|
292
285
|
dataset: snowpark dataframe
|
293
286
|
inference_method: the inference method such as predict, score...
|
294
|
-
|
287
|
+
|
295
288
|
Raises:
|
296
289
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
297
290
|
SnowflakeMLException: If the session is None, raise error
|
298
291
|
|
299
|
-
Returns:
|
300
|
-
A list of available package that exists in the snowflake anaconda channel
|
301
292
|
"""
|
302
293
|
if not self._is_fitted:
|
303
294
|
raise exceptions.SnowflakeMLException(
|
@@ -315,9 +306,7 @@ class PolynomialCountSketch(BaseTransformer):
|
|
315
306
|
"Session must not specified for snowpark dataset."
|
316
307
|
),
|
317
308
|
)
|
318
|
-
|
319
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
320
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
309
|
+
|
321
310
|
|
322
311
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
323
312
|
@telemetry.send_api_usage_telemetry(
|
@@ -363,7 +352,8 @@ class PolynomialCountSketch(BaseTransformer):
|
|
363
352
|
|
364
353
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
365
354
|
|
366
|
-
self.
|
355
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
356
|
+
self._deps = self._get_dependencies()
|
367
357
|
assert isinstance(
|
368
358
|
dataset._session, Session
|
369
359
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -448,10 +438,8 @@ class PolynomialCountSketch(BaseTransformer):
|
|
448
438
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
449
439
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
450
440
|
|
451
|
-
self.
|
452
|
-
|
453
|
-
inference_method=inference_method,
|
454
|
-
)
|
441
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
442
|
+
self._deps = self._get_dependencies()
|
455
443
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
456
444
|
|
457
445
|
transform_kwargs = dict(
|
@@ -518,16 +506,42 @@ class PolynomialCountSketch(BaseTransformer):
|
|
518
506
|
self._is_fitted = True
|
519
507
|
return output_result
|
520
508
|
|
509
|
+
|
510
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
511
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
512
|
+
""" Fit to data, then transform it
|
513
|
+
For more details on this function, see [sklearn.kernel_approximation.PolynomialCountSketch.fit_transform]
|
514
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.kernel_approximation.PolynomialCountSketch.html#sklearn.kernel_approximation.PolynomialCountSketch.fit_transform)
|
515
|
+
|
521
516
|
|
522
|
-
|
523
|
-
|
524
|
-
|
517
|
+
Raises:
|
518
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
519
|
+
|
520
|
+
Args:
|
521
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
522
|
+
Snowpark or Pandas DataFrame.
|
523
|
+
output_cols_prefix: Prefix for the response columns
|
525
524
|
Returns:
|
526
525
|
Transformed dataset.
|
527
526
|
"""
|
528
|
-
self.
|
529
|
-
|
530
|
-
|
527
|
+
self._infer_input_output_cols(dataset)
|
528
|
+
super()._check_dataset_type(dataset)
|
529
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
530
|
+
estimator=self._sklearn_object,
|
531
|
+
dataset=dataset,
|
532
|
+
input_cols=self.input_cols,
|
533
|
+
label_cols=self.label_cols,
|
534
|
+
sample_weight_col=self.sample_weight_col,
|
535
|
+
autogenerated=self._autogenerated,
|
536
|
+
subproject=_SUBPROJECT,
|
537
|
+
)
|
538
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
539
|
+
drop_input_cols=self._drop_input_cols,
|
540
|
+
expected_output_cols_list=self.output_cols,
|
541
|
+
)
|
542
|
+
self._sklearn_object = fitted_estimator
|
543
|
+
self._is_fitted = True
|
544
|
+
return output_result
|
531
545
|
|
532
546
|
|
533
547
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -618,10 +632,8 @@ class PolynomialCountSketch(BaseTransformer):
|
|
618
632
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
619
633
|
|
620
634
|
if isinstance(dataset, DataFrame):
|
621
|
-
self.
|
622
|
-
|
623
|
-
inference_method=inference_method,
|
624
|
-
)
|
635
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
636
|
+
self._deps = self._get_dependencies()
|
625
637
|
assert isinstance(
|
626
638
|
dataset._session, Session
|
627
639
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -686,10 +698,8 @@ class PolynomialCountSketch(BaseTransformer):
|
|
686
698
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
687
699
|
|
688
700
|
if isinstance(dataset, DataFrame):
|
689
|
-
self.
|
690
|
-
|
691
|
-
inference_method=inference_method,
|
692
|
-
)
|
701
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
702
|
+
self._deps = self._get_dependencies()
|
693
703
|
assert isinstance(
|
694
704
|
dataset._session, Session
|
695
705
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -751,10 +761,8 @@ class PolynomialCountSketch(BaseTransformer):
|
|
751
761
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
752
762
|
|
753
763
|
if isinstance(dataset, DataFrame):
|
754
|
-
self.
|
755
|
-
|
756
|
-
inference_method=inference_method,
|
757
|
-
)
|
764
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
765
|
+
self._deps = self._get_dependencies()
|
758
766
|
assert isinstance(
|
759
767
|
dataset._session, Session
|
760
768
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -820,10 +828,8 @@ class PolynomialCountSketch(BaseTransformer):
|
|
820
828
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
821
829
|
|
822
830
|
if isinstance(dataset, DataFrame):
|
823
|
-
self.
|
824
|
-
|
825
|
-
inference_method=inference_method,
|
826
|
-
)
|
831
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
832
|
+
self._deps = self._get_dependencies()
|
827
833
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
828
834
|
transform_kwargs = dict(
|
829
835
|
session=dataset._session,
|
@@ -885,17 +891,15 @@ class PolynomialCountSketch(BaseTransformer):
|
|
885
891
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
886
892
|
|
887
893
|
if isinstance(dataset, DataFrame):
|
888
|
-
self.
|
889
|
-
|
890
|
-
inference_method="score",
|
891
|
-
)
|
894
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
895
|
+
self._deps = self._get_dependencies()
|
892
896
|
selected_cols = self._get_active_columns()
|
893
897
|
if len(selected_cols) > 0:
|
894
898
|
dataset = dataset.select(selected_cols)
|
895
899
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
896
900
|
transform_kwargs = dict(
|
897
901
|
session=dataset._session,
|
898
|
-
dependencies=
|
902
|
+
dependencies=self._deps,
|
899
903
|
score_sproc_imports=['sklearn'],
|
900
904
|
)
|
901
905
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -960,11 +964,8 @@ class PolynomialCountSketch(BaseTransformer):
|
|
960
964
|
|
961
965
|
if isinstance(dataset, DataFrame):
|
962
966
|
|
963
|
-
self.
|
964
|
-
|
965
|
-
inference_method=inference_method,
|
966
|
-
|
967
|
-
)
|
967
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
968
|
+
self._deps = self._get_dependencies()
|
968
969
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
969
970
|
transform_kwargs = dict(
|
970
971
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.kernel_approximation".re
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class RBFSampler(BaseTransformer):
|
70
64
|
r"""Approximate a RBF kernel feature map using random Fourier features
|
71
65
|
For more details on this class, see [sklearn.kernel_approximation.RBFSampler]
|
@@ -271,20 +265,17 @@ class RBFSampler(BaseTransformer):
|
|
271
265
|
self,
|
272
266
|
dataset: DataFrame,
|
273
267
|
inference_method: str,
|
274
|
-
) ->
|
275
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
276
|
-
return the available package that exists in the snowflake anaconda channel
|
268
|
+
) -> None:
|
269
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
277
270
|
|
278
271
|
Args:
|
279
272
|
dataset: snowpark dataframe
|
280
273
|
inference_method: the inference method such as predict, score...
|
281
|
-
|
274
|
+
|
282
275
|
Raises:
|
283
276
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
284
277
|
SnowflakeMLException: If the session is None, raise error
|
285
278
|
|
286
|
-
Returns:
|
287
|
-
A list of available package that exists in the snowflake anaconda channel
|
288
279
|
"""
|
289
280
|
if not self._is_fitted:
|
290
281
|
raise exceptions.SnowflakeMLException(
|
@@ -302,9 +293,7 @@ class RBFSampler(BaseTransformer):
|
|
302
293
|
"Session must not specified for snowpark dataset."
|
303
294
|
),
|
304
295
|
)
|
305
|
-
|
306
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
307
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
296
|
+
|
308
297
|
|
309
298
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
310
299
|
@telemetry.send_api_usage_telemetry(
|
@@ -350,7 +339,8 @@ class RBFSampler(BaseTransformer):
|
|
350
339
|
|
351
340
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
352
341
|
|
353
|
-
self.
|
342
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
343
|
+
self._deps = self._get_dependencies()
|
354
344
|
assert isinstance(
|
355
345
|
dataset._session, Session
|
356
346
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -435,10 +425,8 @@ class RBFSampler(BaseTransformer):
|
|
435
425
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
436
426
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
437
427
|
|
438
|
-
self.
|
439
|
-
|
440
|
-
inference_method=inference_method,
|
441
|
-
)
|
428
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
429
|
+
self._deps = self._get_dependencies()
|
442
430
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
443
431
|
|
444
432
|
transform_kwargs = dict(
|
@@ -505,16 +493,42 @@ class RBFSampler(BaseTransformer):
|
|
505
493
|
self._is_fitted = True
|
506
494
|
return output_result
|
507
495
|
|
496
|
+
|
497
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
498
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
499
|
+
""" Fit to data, then transform it
|
500
|
+
For more details on this function, see [sklearn.kernel_approximation.RBFSampler.fit_transform]
|
501
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.kernel_approximation.RBFSampler.html#sklearn.kernel_approximation.RBFSampler.fit_transform)
|
502
|
+
|
508
503
|
|
509
|
-
|
510
|
-
|
511
|
-
|
504
|
+
Raises:
|
505
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
506
|
+
|
507
|
+
Args:
|
508
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
509
|
+
Snowpark or Pandas DataFrame.
|
510
|
+
output_cols_prefix: Prefix for the response columns
|
512
511
|
Returns:
|
513
512
|
Transformed dataset.
|
514
513
|
"""
|
515
|
-
self.
|
516
|
-
|
517
|
-
|
514
|
+
self._infer_input_output_cols(dataset)
|
515
|
+
super()._check_dataset_type(dataset)
|
516
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
517
|
+
estimator=self._sklearn_object,
|
518
|
+
dataset=dataset,
|
519
|
+
input_cols=self.input_cols,
|
520
|
+
label_cols=self.label_cols,
|
521
|
+
sample_weight_col=self.sample_weight_col,
|
522
|
+
autogenerated=self._autogenerated,
|
523
|
+
subproject=_SUBPROJECT,
|
524
|
+
)
|
525
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
526
|
+
drop_input_cols=self._drop_input_cols,
|
527
|
+
expected_output_cols_list=self.output_cols,
|
528
|
+
)
|
529
|
+
self._sklearn_object = fitted_estimator
|
530
|
+
self._is_fitted = True
|
531
|
+
return output_result
|
518
532
|
|
519
533
|
|
520
534
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -605,10 +619,8 @@ class RBFSampler(BaseTransformer):
|
|
605
619
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
606
620
|
|
607
621
|
if isinstance(dataset, DataFrame):
|
608
|
-
self.
|
609
|
-
|
610
|
-
inference_method=inference_method,
|
611
|
-
)
|
622
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
623
|
+
self._deps = self._get_dependencies()
|
612
624
|
assert isinstance(
|
613
625
|
dataset._session, Session
|
614
626
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -673,10 +685,8 @@ class RBFSampler(BaseTransformer):
|
|
673
685
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
674
686
|
|
675
687
|
if isinstance(dataset, DataFrame):
|
676
|
-
self.
|
677
|
-
|
678
|
-
inference_method=inference_method,
|
679
|
-
)
|
688
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
689
|
+
self._deps = self._get_dependencies()
|
680
690
|
assert isinstance(
|
681
691
|
dataset._session, Session
|
682
692
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -738,10 +748,8 @@ class RBFSampler(BaseTransformer):
|
|
738
748
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
739
749
|
|
740
750
|
if isinstance(dataset, DataFrame):
|
741
|
-
self.
|
742
|
-
|
743
|
-
inference_method=inference_method,
|
744
|
-
)
|
751
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
752
|
+
self._deps = self._get_dependencies()
|
745
753
|
assert isinstance(
|
746
754
|
dataset._session, Session
|
747
755
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -807,10 +815,8 @@ class RBFSampler(BaseTransformer):
|
|
807
815
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
808
816
|
|
809
817
|
if isinstance(dataset, DataFrame):
|
810
|
-
self.
|
811
|
-
|
812
|
-
inference_method=inference_method,
|
813
|
-
)
|
818
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
819
|
+
self._deps = self._get_dependencies()
|
814
820
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
815
821
|
transform_kwargs = dict(
|
816
822
|
session=dataset._session,
|
@@ -872,17 +878,15 @@ class RBFSampler(BaseTransformer):
|
|
872
878
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
873
879
|
|
874
880
|
if isinstance(dataset, DataFrame):
|
875
|
-
self.
|
876
|
-
|
877
|
-
inference_method="score",
|
878
|
-
)
|
881
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
882
|
+
self._deps = self._get_dependencies()
|
879
883
|
selected_cols = self._get_active_columns()
|
880
884
|
if len(selected_cols) > 0:
|
881
885
|
dataset = dataset.select(selected_cols)
|
882
886
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
883
887
|
transform_kwargs = dict(
|
884
888
|
session=dataset._session,
|
885
|
-
dependencies=
|
889
|
+
dependencies=self._deps,
|
886
890
|
score_sproc_imports=['sklearn'],
|
887
891
|
)
|
888
892
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -947,11 +951,8 @@ class RBFSampler(BaseTransformer):
|
|
947
951
|
|
948
952
|
if isinstance(dataset, DataFrame):
|
949
953
|
|
950
|
-
self.
|
951
|
-
|
952
|
-
inference_method=inference_method,
|
953
|
-
|
954
|
-
)
|
954
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
955
|
+
self._deps = self._get_dependencies()
|
955
956
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
956
957
|
transform_kwargs = dict(
|
957
958
|
session = dataset._session,
|