snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff shows the changes between these publicly released package versions as they appear in their public registry, and is provided for informational purposes only.
- snowflake/ml/_internal/env_utils.py +72 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
- snowflake/ml/_internal/telemetry.py +1 -0
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/sql_identifier.py +14 -1
- snowflake/ml/dataset/__init__.py +11 -0
- snowflake/ml/dataset/dataset.py +455 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +199 -0
- snowflake/ml/feature_store/__init__.py +6 -0
- snowflake/ml/feature_store/access_manager.py +279 -0
- snowflake/ml/feature_store/feature_store.py +544 -358
- snowflake/ml/feature_store/feature_view.py +55 -16
- snowflake/ml/fileset/embedded_stage_fs.py +149 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +160 -0
- snowflake/ml/fileset/stage_fs.py +25 -10
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +65 -31
- snowflake/ml/model/_client/model/model_version_impl.py +159 -2
- snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
- snowflake/ml/model/_client/ops/model_ops.py +268 -83
- snowflake/ml/model/_client/sql/_base.py +34 -0
- snowflake/ml/model/_client/sql/model.py +42 -47
- snowflake/ml/model/_client/sql/model_version.py +164 -39
- snowflake/ml/model/_client/sql/stage.py +6 -32
- snowflake/ml/model/_client/sql/tag.py +32 -56
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +64 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +538 -36
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/_manager/model_manager.py +36 -7
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/cluster/affinity_propagation.py (+51 -52)

```diff
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklear
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class AffinityPropagation(BaseTransformer):
     r"""Perform Affinity Propagation Clustering of data
     For more details on this class, see [sklearn.cluster.AffinityPropagation]
@@ -303,20 +297,17 @@ class AffinityPropagation(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -334,9 +325,7 @@ class AffinityPropagation(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -384,7 +373,8 @@ class AffinityPropagation(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -467,10 +457,8 @@ class AffinityPropagation(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -539,16 +527,40 @@ class AffinityPropagation(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
+
 
-
-
-
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -639,10 +651,8 @@ class AffinityPropagation(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -707,10 +717,8 @@ class AffinityPropagation(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -772,10 +780,8 @@ class AffinityPropagation(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -841,10 +847,8 @@ class AffinityPropagation(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -906,17 +910,15 @@ class AffinityPropagation(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session) # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -981,11 +983,8 @@ class AffinityPropagation(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
```
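The change that repeats through the AffinityPropagation hunks above, and through the other autogenerated modeling classes in the file list, is internal: `_batch_inference_validate_snowpark` now only validates and returns `None`, and each call site fetches the dependency list separately with `self._deps = self._get_dependencies()` instead of receiving it from the validator. The public estimator surface is unchanged. Below is a minimal sketch of that unchanged surface, assuming snowflake-ml-python 1.5.x with scikit-learn available for the local pandas path; the column names and data are illustrative, not taken from the diff.

```python
# Sketch only: shows that the user-facing estimator API is unaffected by the refactor above.
import pandas as pd
from snowflake.ml.modeling.cluster import AffinityPropagation

# Illustrative toy data; any numeric feature columns would do.
df = pd.DataFrame({"X1": [1.0, 1.1, 8.0, 8.2], "X2": [0.9, 1.0, 8.1, 8.3]})

clusterer = AffinityPropagation(
    input_cols=["X1", "X2"],   # feature columns (hypothetical names)
    output_cols=["CLUSTER"],   # column that predict() appends
)
clusterer.fit(df)               # pandas input trains locally
result = clusterer.predict(df)  # on a Snowpark DataFrame this is the path where
                                # _batch_inference_validate_snowpark and
                                # _get_dependencies are invoked
print(result["CLUSTER"].tolist())
```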
snowflake/ml/modeling/cluster/agglomerative_clustering.py (+51 -52)

```diff
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklear
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class AgglomerativeClustering(BaseTransformer):
     r"""Agglomerative Clustering
     For more details on this class, see [sklearn.cluster.AgglomerativeClustering]
@@ -336,20 +330,17 @@ class AgglomerativeClustering(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -367,9 +358,7 @@ class AgglomerativeClustering(BaseTransformer):
                    "Session must not specified for snowpark dataset."
                ),
            )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -415,7 +404,8 @@ class AgglomerativeClustering(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -498,10 +488,8 @@ class AgglomerativeClustering(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -570,16 +558,40 @@ class AgglomerativeClustering(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
+
 
-
-
-
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -670,10 +682,8 @@ class AgglomerativeClustering(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -738,10 +748,8 @@ class AgglomerativeClustering(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -803,10 +811,8 @@ class AgglomerativeClustering(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -872,10 +878,8 @@ class AgglomerativeClustering(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -937,17 +941,15 @@ class AgglomerativeClustering(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session) # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -1012,11 +1014,8 @@ class AgglomerativeClustering(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
```
snowflake/ml/modeling/cluster/birch.py (+53 -52)

```diff
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklear
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class Birch(BaseTransformer):
     r"""Implements the BIRCH clustering algorithm
     For more details on this class, see [sklearn.cluster.Birch]
@@ -294,20 +288,17 @@ class Birch(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -325,9 +316,7 @@ class Birch(BaseTransformer):
                    "Session must not specified for snowpark dataset."
                ),
            )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -375,7 +364,8 @@ class Birch(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -460,10 +450,8 @@ class Birch(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -532,16 +520,42 @@ class Birch(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Fit to data, then transform it
+        For more details on this function, see [sklearn.cluster.Birch.fit_transform]
+        (https://scikit-learn.org/stable/modules/generated/sklearn.cluster.Birch.html#sklearn.cluster.Birch.fit_transform)
+
 
-
-
-
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -632,10 +646,8 @@ class Birch(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -700,10 +712,8 @@ class Birch(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -765,10 +775,8 @@ class Birch(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -834,10 +842,8 @@ class Birch(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -899,17 +905,15 @@ class Birch(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session) # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
            )
         elif isinstance(dataset, pd.DataFrame):
@@ -974,11 +978,8 @@ class Birch(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
```
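The other repeated addition is the fit_transform method wired through ModelTrainerBuilder.build_fit_transform. For Birch it dispatches to sklearn.cluster.Birch.fit_transform; the modules listed above with +53/-52 (for example snowflake/ml/modeling/decomposition/pca.py) gain the same entry point. A hedged usage sketch follows, using PCA because its output width is predictable; the column names, data, and output column names are illustrative assumptions, not taken from the diff.

```python
# Sketch only: exercises the fit_transform entry point added in this release,
# assuming snowflake-ml-python 1.5.x with scikit-learn available locally.
import pandas as pd
from snowflake.ml.modeling.decomposition import PCA

# Illustrative toy data with three numeric feature columns.
df = pd.DataFrame({
    "X1": [1.0, 2.0, 3.0, 4.0],
    "X2": [1.1, 1.9, 3.2, 3.9],
    "X3": [0.5, 0.4, 0.6, 0.5],
})

pca = PCA(
    n_components=2,
    input_cols=["X1", "X2", "X3"],
    output_cols=["PC1", "PC2"],  # set explicitly so the output width matches n_components
)
# New in 1.5.x: fit and transform in one call (train_fit_transform under the hood).
projected = pca.fit_transform(df)
print(projected[["PC1", "PC2"]])
```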