snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklear
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class BisectingKMeans(BaseTransformer):
|
70
64
|
r"""Bisecting K-Means clustering
|
71
65
|
For more details on this class, see [sklearn.cluster.BisectingKMeans]
|
@@ -343,20 +337,17 @@ class BisectingKMeans(BaseTransformer):
|
|
343
337
|
self,
|
344
338
|
dataset: DataFrame,
|
345
339
|
inference_method: str,
|
346
|
-
) ->
|
347
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
348
|
-
return the available package that exists in the snowflake anaconda channel
|
340
|
+
) -> None:
|
341
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
349
342
|
|
350
343
|
Args:
|
351
344
|
dataset: snowpark dataframe
|
352
345
|
inference_method: the inference method such as predict, score...
|
353
|
-
|
346
|
+
|
354
347
|
Raises:
|
355
348
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
356
349
|
SnowflakeMLException: If the session is None, raise error
|
357
350
|
|
358
|
-
Returns:
|
359
|
-
A list of available package that exists in the snowflake anaconda channel
|
360
351
|
"""
|
361
352
|
if not self._is_fitted:
|
362
353
|
raise exceptions.SnowflakeMLException(
|
@@ -374,9 +365,7 @@ class BisectingKMeans(BaseTransformer):
|
|
374
365
|
"Session must not specified for snowpark dataset."
|
375
366
|
),
|
376
367
|
)
|
377
|
-
|
378
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
379
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
368
|
+
|
380
369
|
|
381
370
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
382
371
|
@telemetry.send_api_usage_telemetry(
|
@@ -424,7 +413,8 @@ class BisectingKMeans(BaseTransformer):
|
|
424
413
|
|
425
414
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
426
415
|
|
427
|
-
self.
|
416
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
417
|
+
self._deps = self._get_dependencies()
|
428
418
|
assert isinstance(
|
429
419
|
dataset._session, Session
|
430
420
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -509,10 +499,8 @@ class BisectingKMeans(BaseTransformer):
|
|
509
499
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
510
500
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
511
501
|
|
512
|
-
self.
|
513
|
-
|
514
|
-
inference_method=inference_method,
|
515
|
-
)
|
502
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
503
|
+
self._deps = self._get_dependencies()
|
516
504
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
517
505
|
|
518
506
|
transform_kwargs = dict(
|
@@ -581,16 +569,42 @@ class BisectingKMeans(BaseTransformer):
|
|
581
569
|
self._is_fitted = True
|
582
570
|
return output_result
|
583
571
|
|
572
|
+
|
573
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
574
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
575
|
+
""" Compute clustering and transform X to cluster-distance space
|
576
|
+
For more details on this function, see [sklearn.cluster.BisectingKMeans.fit_transform]
|
577
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.cluster.BisectingKMeans.html#sklearn.cluster.BisectingKMeans.fit_transform)
|
578
|
+
|
584
579
|
|
585
|
-
|
586
|
-
|
587
|
-
|
580
|
+
Raises:
|
581
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
582
|
+
|
583
|
+
Args:
|
584
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
585
|
+
Snowpark or Pandas DataFrame.
|
586
|
+
output_cols_prefix: Prefix for the response columns
|
588
587
|
Returns:
|
589
588
|
Transformed dataset.
|
590
589
|
"""
|
591
|
-
self.
|
592
|
-
|
593
|
-
|
590
|
+
self._infer_input_output_cols(dataset)
|
591
|
+
super()._check_dataset_type(dataset)
|
592
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
593
|
+
estimator=self._sklearn_object,
|
594
|
+
dataset=dataset,
|
595
|
+
input_cols=self.input_cols,
|
596
|
+
label_cols=self.label_cols,
|
597
|
+
sample_weight_col=self.sample_weight_col,
|
598
|
+
autogenerated=self._autogenerated,
|
599
|
+
subproject=_SUBPROJECT,
|
600
|
+
)
|
601
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
602
|
+
drop_input_cols=self._drop_input_cols,
|
603
|
+
expected_output_cols_list=self.output_cols,
|
604
|
+
)
|
605
|
+
self._sklearn_object = fitted_estimator
|
606
|
+
self._is_fitted = True
|
607
|
+
return output_result
|
594
608
|
|
595
609
|
|
596
610
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -681,10 +695,8 @@ class BisectingKMeans(BaseTransformer):
|
|
681
695
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
682
696
|
|
683
697
|
if isinstance(dataset, DataFrame):
|
684
|
-
self.
|
685
|
-
|
686
|
-
inference_method=inference_method,
|
687
|
-
)
|
698
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
699
|
+
self._deps = self._get_dependencies()
|
688
700
|
assert isinstance(
|
689
701
|
dataset._session, Session
|
690
702
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -749,10 +761,8 @@ class BisectingKMeans(BaseTransformer):
|
|
749
761
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
750
762
|
|
751
763
|
if isinstance(dataset, DataFrame):
|
752
|
-
self.
|
753
|
-
|
754
|
-
inference_method=inference_method,
|
755
|
-
)
|
764
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
765
|
+
self._deps = self._get_dependencies()
|
756
766
|
assert isinstance(
|
757
767
|
dataset._session, Session
|
758
768
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -814,10 +824,8 @@ class BisectingKMeans(BaseTransformer):
|
|
814
824
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
815
825
|
|
816
826
|
if isinstance(dataset, DataFrame):
|
817
|
-
self.
|
818
|
-
|
819
|
-
inference_method=inference_method,
|
820
|
-
)
|
827
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
828
|
+
self._deps = self._get_dependencies()
|
821
829
|
assert isinstance(
|
822
830
|
dataset._session, Session
|
823
831
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -883,10 +891,8 @@ class BisectingKMeans(BaseTransformer):
|
|
883
891
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
884
892
|
|
885
893
|
if isinstance(dataset, DataFrame):
|
886
|
-
self.
|
887
|
-
|
888
|
-
inference_method=inference_method,
|
889
|
-
)
|
894
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
895
|
+
self._deps = self._get_dependencies()
|
890
896
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
891
897
|
transform_kwargs = dict(
|
892
898
|
session=dataset._session,
|
@@ -950,17 +956,15 @@ class BisectingKMeans(BaseTransformer):
|
|
950
956
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
951
957
|
|
952
958
|
if isinstance(dataset, DataFrame):
|
953
|
-
self.
|
954
|
-
|
955
|
-
inference_method="score",
|
956
|
-
)
|
959
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
960
|
+
self._deps = self._get_dependencies()
|
957
961
|
selected_cols = self._get_active_columns()
|
958
962
|
if len(selected_cols) > 0:
|
959
963
|
dataset = dataset.select(selected_cols)
|
960
964
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
961
965
|
transform_kwargs = dict(
|
962
966
|
session=dataset._session,
|
963
|
-
dependencies=
|
967
|
+
dependencies=self._deps,
|
964
968
|
score_sproc_imports=['sklearn'],
|
965
969
|
)
|
966
970
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1025,11 +1029,8 @@ class BisectingKMeans(BaseTransformer):
|
|
1025
1029
|
|
1026
1030
|
if isinstance(dataset, DataFrame):
|
1027
1031
|
|
1028
|
-
self.
|
1029
|
-
|
1030
|
-
inference_method=inference_method,
|
1031
|
-
|
1032
|
-
)
|
1032
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1033
|
+
self._deps = self._get_dependencies()
|
1033
1034
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1034
1035
|
transform_kwargs = dict(
|
1035
1036
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklear
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class DBSCAN(BaseTransformer):
|
70
64
|
r"""Perform DBSCAN clustering from vector array or distance matrix
|
71
65
|
For more details on this class, see [sklearn.cluster.DBSCAN]
|
@@ -311,20 +305,17 @@ class DBSCAN(BaseTransformer):
|
|
311
305
|
self,
|
312
306
|
dataset: DataFrame,
|
313
307
|
inference_method: str,
|
314
|
-
) ->
|
315
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
316
|
-
return the available package that exists in the snowflake anaconda channel
|
308
|
+
) -> None:
|
309
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
317
310
|
|
318
311
|
Args:
|
319
312
|
dataset: snowpark dataframe
|
320
313
|
inference_method: the inference method such as predict, score...
|
321
|
-
|
314
|
+
|
322
315
|
Raises:
|
323
316
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
324
317
|
SnowflakeMLException: If the session is None, raise error
|
325
318
|
|
326
|
-
Returns:
|
327
|
-
A list of available package that exists in the snowflake anaconda channel
|
328
319
|
"""
|
329
320
|
if not self._is_fitted:
|
330
321
|
raise exceptions.SnowflakeMLException(
|
@@ -342,9 +333,7 @@ class DBSCAN(BaseTransformer):
|
|
342
333
|
"Session must not specified for snowpark dataset."
|
343
334
|
),
|
344
335
|
)
|
345
|
-
|
346
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
347
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
336
|
+
|
348
337
|
|
349
338
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
350
339
|
@telemetry.send_api_usage_telemetry(
|
@@ -390,7 +379,8 @@ class DBSCAN(BaseTransformer):
|
|
390
379
|
|
391
380
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
392
381
|
|
393
|
-
self.
|
382
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
383
|
+
self._deps = self._get_dependencies()
|
394
384
|
assert isinstance(
|
395
385
|
dataset._session, Session
|
396
386
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -473,10 +463,8 @@ class DBSCAN(BaseTransformer):
|
|
473
463
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
474
464
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
475
465
|
|
476
|
-
self.
|
477
|
-
|
478
|
-
inference_method=inference_method,
|
479
|
-
)
|
466
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
467
|
+
self._deps = self._get_dependencies()
|
480
468
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
481
469
|
|
482
470
|
transform_kwargs = dict(
|
@@ -545,16 +533,40 @@ class DBSCAN(BaseTransformer):
|
|
545
533
|
self._is_fitted = True
|
546
534
|
return output_result
|
547
535
|
|
536
|
+
|
537
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
538
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
539
|
+
""" Method not supported for this class.
|
540
|
+
|
548
541
|
|
549
|
-
|
550
|
-
|
551
|
-
|
542
|
+
Raises:
|
543
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
544
|
+
|
545
|
+
Args:
|
546
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
547
|
+
Snowpark or Pandas DataFrame.
|
548
|
+
output_cols_prefix: Prefix for the response columns
|
552
549
|
Returns:
|
553
550
|
Transformed dataset.
|
554
551
|
"""
|
555
|
-
self.
|
556
|
-
|
557
|
-
|
552
|
+
self._infer_input_output_cols(dataset)
|
553
|
+
super()._check_dataset_type(dataset)
|
554
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
555
|
+
estimator=self._sklearn_object,
|
556
|
+
dataset=dataset,
|
557
|
+
input_cols=self.input_cols,
|
558
|
+
label_cols=self.label_cols,
|
559
|
+
sample_weight_col=self.sample_weight_col,
|
560
|
+
autogenerated=self._autogenerated,
|
561
|
+
subproject=_SUBPROJECT,
|
562
|
+
)
|
563
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
564
|
+
drop_input_cols=self._drop_input_cols,
|
565
|
+
expected_output_cols_list=self.output_cols,
|
566
|
+
)
|
567
|
+
self._sklearn_object = fitted_estimator
|
568
|
+
self._is_fitted = True
|
569
|
+
return output_result
|
558
570
|
|
559
571
|
|
560
572
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -645,10 +657,8 @@ class DBSCAN(BaseTransformer):
|
|
645
657
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
646
658
|
|
647
659
|
if isinstance(dataset, DataFrame):
|
648
|
-
self.
|
649
|
-
|
650
|
-
inference_method=inference_method,
|
651
|
-
)
|
660
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
661
|
+
self._deps = self._get_dependencies()
|
652
662
|
assert isinstance(
|
653
663
|
dataset._session, Session
|
654
664
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -713,10 +723,8 @@ class DBSCAN(BaseTransformer):
|
|
713
723
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
714
724
|
|
715
725
|
if isinstance(dataset, DataFrame):
|
716
|
-
self.
|
717
|
-
|
718
|
-
inference_method=inference_method,
|
719
|
-
)
|
726
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
727
|
+
self._deps = self._get_dependencies()
|
720
728
|
assert isinstance(
|
721
729
|
dataset._session, Session
|
722
730
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -778,10 +786,8 @@ class DBSCAN(BaseTransformer):
|
|
778
786
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
779
787
|
|
780
788
|
if isinstance(dataset, DataFrame):
|
781
|
-
self.
|
782
|
-
|
783
|
-
inference_method=inference_method,
|
784
|
-
)
|
789
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
790
|
+
self._deps = self._get_dependencies()
|
785
791
|
assert isinstance(
|
786
792
|
dataset._session, Session
|
787
793
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -847,10 +853,8 @@ class DBSCAN(BaseTransformer):
|
|
847
853
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
848
854
|
|
849
855
|
if isinstance(dataset, DataFrame):
|
850
|
-
self.
|
851
|
-
|
852
|
-
inference_method=inference_method,
|
853
|
-
)
|
856
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
857
|
+
self._deps = self._get_dependencies()
|
854
858
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
855
859
|
transform_kwargs = dict(
|
856
860
|
session=dataset._session,
|
@@ -912,17 +916,15 @@ class DBSCAN(BaseTransformer):
|
|
912
916
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
913
917
|
|
914
918
|
if isinstance(dataset, DataFrame):
|
915
|
-
self.
|
916
|
-
|
917
|
-
inference_method="score",
|
918
|
-
)
|
919
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
920
|
+
self._deps = self._get_dependencies()
|
919
921
|
selected_cols = self._get_active_columns()
|
920
922
|
if len(selected_cols) > 0:
|
921
923
|
dataset = dataset.select(selected_cols)
|
922
924
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
923
925
|
transform_kwargs = dict(
|
924
926
|
session=dataset._session,
|
925
|
-
dependencies=
|
927
|
+
dependencies=self._deps,
|
926
928
|
score_sproc_imports=['sklearn'],
|
927
929
|
)
|
928
930
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -987,11 +989,8 @@ class DBSCAN(BaseTransformer):
|
|
987
989
|
|
988
990
|
if isinstance(dataset, DataFrame):
|
989
991
|
|
990
|
-
self.
|
991
|
-
|
992
|
-
inference_method=inference_method,
|
993
|
-
|
994
|
-
)
|
992
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
993
|
+
self._deps = self._get_dependencies()
|
995
994
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
996
995
|
transform_kwargs = dict(
|
997
996
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklear
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class FeatureAgglomeration(BaseTransformer):
|
70
64
|
r"""Agglomerate features
|
71
65
|
For more details on this class, see [sklearn.cluster.FeatureAgglomeration]
|
@@ -343,20 +337,17 @@ class FeatureAgglomeration(BaseTransformer):
|
|
343
337
|
self,
|
344
338
|
dataset: DataFrame,
|
345
339
|
inference_method: str,
|
346
|
-
) ->
|
347
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
348
|
-
return the available package that exists in the snowflake anaconda channel
|
340
|
+
) -> None:
|
341
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
349
342
|
|
350
343
|
Args:
|
351
344
|
dataset: snowpark dataframe
|
352
345
|
inference_method: the inference method such as predict, score...
|
353
|
-
|
346
|
+
|
354
347
|
Raises:
|
355
348
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
356
349
|
SnowflakeMLException: If the session is None, raise error
|
357
350
|
|
358
|
-
Returns:
|
359
|
-
A list of available package that exists in the snowflake anaconda channel
|
360
351
|
"""
|
361
352
|
if not self._is_fitted:
|
362
353
|
raise exceptions.SnowflakeMLException(
|
@@ -374,9 +365,7 @@ class FeatureAgglomeration(BaseTransformer):
|
|
374
365
|
"Session must not specified for snowpark dataset."
|
375
366
|
),
|
376
367
|
)
|
377
|
-
|
378
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
379
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
368
|
+
|
380
369
|
|
381
370
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
382
371
|
@telemetry.send_api_usage_telemetry(
|
@@ -422,7 +411,8 @@ class FeatureAgglomeration(BaseTransformer):
|
|
422
411
|
|
423
412
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
424
413
|
|
425
|
-
self.
|
414
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
415
|
+
self._deps = self._get_dependencies()
|
426
416
|
assert isinstance(
|
427
417
|
dataset._session, Session
|
428
418
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -507,10 +497,8 @@ class FeatureAgglomeration(BaseTransformer):
|
|
507
497
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
508
498
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
509
499
|
|
510
|
-
self.
|
511
|
-
|
512
|
-
inference_method=inference_method,
|
513
|
-
)
|
500
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
501
|
+
self._deps = self._get_dependencies()
|
514
502
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
515
503
|
|
516
504
|
transform_kwargs = dict(
|
@@ -579,16 +567,42 @@ class FeatureAgglomeration(BaseTransformer):
|
|
579
567
|
self._is_fitted = True
|
580
568
|
return output_result
|
581
569
|
|
570
|
+
|
571
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
572
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
573
|
+
""" Fit to data, then transform it
|
574
|
+
For more details on this function, see [sklearn.cluster.FeatureAgglomeration.fit_transform]
|
575
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.cluster.FeatureAgglomeration.html#sklearn.cluster.FeatureAgglomeration.fit_transform)
|
576
|
+
|
582
577
|
|
583
|
-
|
584
|
-
|
585
|
-
|
578
|
+
Raises:
|
579
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
580
|
+
|
581
|
+
Args:
|
582
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
583
|
+
Snowpark or Pandas DataFrame.
|
584
|
+
output_cols_prefix: Prefix for the response columns
|
586
585
|
Returns:
|
587
586
|
Transformed dataset.
|
588
587
|
"""
|
589
|
-
self.
|
590
|
-
|
591
|
-
|
588
|
+
self._infer_input_output_cols(dataset)
|
589
|
+
super()._check_dataset_type(dataset)
|
590
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
591
|
+
estimator=self._sklearn_object,
|
592
|
+
dataset=dataset,
|
593
|
+
input_cols=self.input_cols,
|
594
|
+
label_cols=self.label_cols,
|
595
|
+
sample_weight_col=self.sample_weight_col,
|
596
|
+
autogenerated=self._autogenerated,
|
597
|
+
subproject=_SUBPROJECT,
|
598
|
+
)
|
599
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
600
|
+
drop_input_cols=self._drop_input_cols,
|
601
|
+
expected_output_cols_list=self.output_cols,
|
602
|
+
)
|
603
|
+
self._sklearn_object = fitted_estimator
|
604
|
+
self._is_fitted = True
|
605
|
+
return output_result
|
592
606
|
|
593
607
|
|
594
608
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -679,10 +693,8 @@ class FeatureAgglomeration(BaseTransformer):
|
|
679
693
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
680
694
|
|
681
695
|
if isinstance(dataset, DataFrame):
|
682
|
-
self.
|
683
|
-
|
684
|
-
inference_method=inference_method,
|
685
|
-
)
|
696
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
697
|
+
self._deps = self._get_dependencies()
|
686
698
|
assert isinstance(
|
687
699
|
dataset._session, Session
|
688
700
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -747,10 +759,8 @@ class FeatureAgglomeration(BaseTransformer):
|
|
747
759
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
748
760
|
|
749
761
|
if isinstance(dataset, DataFrame):
|
750
|
-
self.
|
751
|
-
|
752
|
-
inference_method=inference_method,
|
753
|
-
)
|
762
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
763
|
+
self._deps = self._get_dependencies()
|
754
764
|
assert isinstance(
|
755
765
|
dataset._session, Session
|
756
766
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -812,10 +822,8 @@ class FeatureAgglomeration(BaseTransformer):
|
|
812
822
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
813
823
|
|
814
824
|
if isinstance(dataset, DataFrame):
|
815
|
-
self.
|
816
|
-
|
817
|
-
inference_method=inference_method,
|
818
|
-
)
|
825
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
826
|
+
self._deps = self._get_dependencies()
|
819
827
|
assert isinstance(
|
820
828
|
dataset._session, Session
|
821
829
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -881,10 +889,8 @@ class FeatureAgglomeration(BaseTransformer):
|
|
881
889
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
882
890
|
|
883
891
|
if isinstance(dataset, DataFrame):
|
884
|
-
self.
|
885
|
-
|
886
|
-
inference_method=inference_method,
|
887
|
-
)
|
892
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
893
|
+
self._deps = self._get_dependencies()
|
888
894
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
889
895
|
transform_kwargs = dict(
|
890
896
|
session=dataset._session,
|
@@ -946,17 +952,15 @@ class FeatureAgglomeration(BaseTransformer):
|
|
946
952
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
947
953
|
|
948
954
|
if isinstance(dataset, DataFrame):
|
949
|
-
self.
|
950
|
-
|
951
|
-
inference_method="score",
|
952
|
-
)
|
955
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
956
|
+
self._deps = self._get_dependencies()
|
953
957
|
selected_cols = self._get_active_columns()
|
954
958
|
if len(selected_cols) > 0:
|
955
959
|
dataset = dataset.select(selected_cols)
|
956
960
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
957
961
|
transform_kwargs = dict(
|
958
962
|
session=dataset._session,
|
959
|
-
dependencies=
|
963
|
+
dependencies=self._deps,
|
960
964
|
score_sproc_imports=['sklearn'],
|
961
965
|
)
|
962
966
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1021,11 +1025,8 @@ class FeatureAgglomeration(BaseTransformer):
|
|
1021
1025
|
|
1022
1026
|
if isinstance(dataset, DataFrame):
|
1023
1027
|
|
1024
|
-
self.
|
1025
|
-
|
1026
|
-
inference_method=inference_method,
|
1027
|
-
|
1028
|
-
)
|
1028
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1029
|
+
self._deps = self._get_dependencies()
|
1029
1030
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1030
1031
|
transform_kwargs = dict(
|
1031
1032
|
session = dataset._session,
|