snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklear
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class KMeans(BaseTransformer):
|
70
64
|
r"""K-Means clustering
|
71
65
|
For more details on this class, see [sklearn.cluster.KMeans]
|
@@ -338,20 +332,17 @@ class KMeans(BaseTransformer):
|
|
338
332
|
self,
|
339
333
|
dataset: DataFrame,
|
340
334
|
inference_method: str,
|
341
|
-
) ->
|
342
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
343
|
-
return the available package that exists in the snowflake anaconda channel
|
335
|
+
) -> None:
|
336
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
344
337
|
|
345
338
|
Args:
|
346
339
|
dataset: snowpark dataframe
|
347
340
|
inference_method: the inference method such as predict, score...
|
348
|
-
|
341
|
+
|
349
342
|
Raises:
|
350
343
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
351
344
|
SnowflakeMLException: If the session is None, raise error
|
352
345
|
|
353
|
-
Returns:
|
354
|
-
A list of available package that exists in the snowflake anaconda channel
|
355
346
|
"""
|
356
347
|
if not self._is_fitted:
|
357
348
|
raise exceptions.SnowflakeMLException(
|
@@ -369,9 +360,7 @@ class KMeans(BaseTransformer):
|
|
369
360
|
"Session must not specified for snowpark dataset."
|
370
361
|
),
|
371
362
|
)
|
372
|
-
|
373
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
374
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
363
|
+
|
375
364
|
|
376
365
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
377
366
|
@telemetry.send_api_usage_telemetry(
|
@@ -419,7 +408,8 @@ class KMeans(BaseTransformer):
|
|
419
408
|
|
420
409
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
421
410
|
|
422
|
-
self.
|
411
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
412
|
+
self._deps = self._get_dependencies()
|
423
413
|
assert isinstance(
|
424
414
|
dataset._session, Session
|
425
415
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -504,10 +494,8 @@ class KMeans(BaseTransformer):
|
|
504
494
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
505
495
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
506
496
|
|
507
|
-
self.
|
508
|
-
|
509
|
-
inference_method=inference_method,
|
510
|
-
)
|
497
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
498
|
+
self._deps = self._get_dependencies()
|
511
499
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
512
500
|
|
513
501
|
transform_kwargs = dict(
|
@@ -576,16 +564,42 @@ class KMeans(BaseTransformer):
|
|
576
564
|
self._is_fitted = True
|
577
565
|
return output_result
|
578
566
|
|
567
|
+
|
568
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
569
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
570
|
+
""" Compute clustering and transform X to cluster-distance space
|
571
|
+
For more details on this function, see [sklearn.cluster.KMeans.fit_transform]
|
572
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html#sklearn.cluster.KMeans.fit_transform)
|
573
|
+
|
579
574
|
|
580
|
-
|
581
|
-
|
582
|
-
|
575
|
+
Raises:
|
576
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
577
|
+
|
578
|
+
Args:
|
579
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
580
|
+
Snowpark or Pandas DataFrame.
|
581
|
+
output_cols_prefix: Prefix for the response columns
|
583
582
|
Returns:
|
584
583
|
Transformed dataset.
|
585
584
|
"""
|
586
|
-
self.
|
587
|
-
|
588
|
-
|
585
|
+
self._infer_input_output_cols(dataset)
|
586
|
+
super()._check_dataset_type(dataset)
|
587
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
588
|
+
estimator=self._sklearn_object,
|
589
|
+
dataset=dataset,
|
590
|
+
input_cols=self.input_cols,
|
591
|
+
label_cols=self.label_cols,
|
592
|
+
sample_weight_col=self.sample_weight_col,
|
593
|
+
autogenerated=self._autogenerated,
|
594
|
+
subproject=_SUBPROJECT,
|
595
|
+
)
|
596
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
597
|
+
drop_input_cols=self._drop_input_cols,
|
598
|
+
expected_output_cols_list=self.output_cols,
|
599
|
+
)
|
600
|
+
self._sklearn_object = fitted_estimator
|
601
|
+
self._is_fitted = True
|
602
|
+
return output_result
|
589
603
|
|
590
604
|
|
591
605
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -676,10 +690,8 @@ class KMeans(BaseTransformer):
|
|
676
690
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
677
691
|
|
678
692
|
if isinstance(dataset, DataFrame):
|
679
|
-
self.
|
680
|
-
|
681
|
-
inference_method=inference_method,
|
682
|
-
)
|
693
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
694
|
+
self._deps = self._get_dependencies()
|
683
695
|
assert isinstance(
|
684
696
|
dataset._session, Session
|
685
697
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -744,10 +756,8 @@ class KMeans(BaseTransformer):
|
|
744
756
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
745
757
|
|
746
758
|
if isinstance(dataset, DataFrame):
|
747
|
-
self.
|
748
|
-
|
749
|
-
inference_method=inference_method,
|
750
|
-
)
|
759
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
760
|
+
self._deps = self._get_dependencies()
|
751
761
|
assert isinstance(
|
752
762
|
dataset._session, Session
|
753
763
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -809,10 +819,8 @@ class KMeans(BaseTransformer):
|
|
809
819
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
810
820
|
|
811
821
|
if isinstance(dataset, DataFrame):
|
812
|
-
self.
|
813
|
-
|
814
|
-
inference_method=inference_method,
|
815
|
-
)
|
822
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
823
|
+
self._deps = self._get_dependencies()
|
816
824
|
assert isinstance(
|
817
825
|
dataset._session, Session
|
818
826
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -878,10 +886,8 @@ class KMeans(BaseTransformer):
|
|
878
886
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
879
887
|
|
880
888
|
if isinstance(dataset, DataFrame):
|
881
|
-
self.
|
882
|
-
|
883
|
-
inference_method=inference_method,
|
884
|
-
)
|
889
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
890
|
+
self._deps = self._get_dependencies()
|
885
891
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
886
892
|
transform_kwargs = dict(
|
887
893
|
session=dataset._session,
|
@@ -945,17 +951,15 @@ class KMeans(BaseTransformer):
|
|
945
951
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
946
952
|
|
947
953
|
if isinstance(dataset, DataFrame):
|
948
|
-
self.
|
949
|
-
|
950
|
-
inference_method="score",
|
951
|
-
)
|
954
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
955
|
+
self._deps = self._get_dependencies()
|
952
956
|
selected_cols = self._get_active_columns()
|
953
957
|
if len(selected_cols) > 0:
|
954
958
|
dataset = dataset.select(selected_cols)
|
955
959
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
956
960
|
transform_kwargs = dict(
|
957
961
|
session=dataset._session,
|
958
|
-
dependencies=
|
962
|
+
dependencies=self._deps,
|
959
963
|
score_sproc_imports=['sklearn'],
|
960
964
|
)
|
961
965
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1020,11 +1024,8 @@ class KMeans(BaseTransformer):
|
|
1020
1024
|
|
1021
1025
|
if isinstance(dataset, DataFrame):
|
1022
1026
|
|
1023
|
-
self.
|
1024
|
-
|
1025
|
-
inference_method=inference_method,
|
1026
|
-
|
1027
|
-
)
|
1027
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1028
|
+
self._deps = self._get_dependencies()
|
1028
1029
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1029
1030
|
transform_kwargs = dict(
|
1030
1031
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklear
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class MeanShift(BaseTransformer):
|
70
64
|
r"""Mean shift clustering using a flat kernel
|
71
65
|
For more details on this class, see [sklearn.cluster.MeanShift]
|
@@ -314,20 +308,17 @@ class MeanShift(BaseTransformer):
|
|
314
308
|
self,
|
315
309
|
dataset: DataFrame,
|
316
310
|
inference_method: str,
|
317
|
-
) ->
|
318
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
319
|
-
return the available package that exists in the snowflake anaconda channel
|
311
|
+
) -> None:
|
312
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
320
313
|
|
321
314
|
Args:
|
322
315
|
dataset: snowpark dataframe
|
323
316
|
inference_method: the inference method such as predict, score...
|
324
|
-
|
317
|
+
|
325
318
|
Raises:
|
326
319
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
327
320
|
SnowflakeMLException: If the session is None, raise error
|
328
321
|
|
329
|
-
Returns:
|
330
|
-
A list of available package that exists in the snowflake anaconda channel
|
331
322
|
"""
|
332
323
|
if not self._is_fitted:
|
333
324
|
raise exceptions.SnowflakeMLException(
|
@@ -345,9 +336,7 @@ class MeanShift(BaseTransformer):
|
|
345
336
|
"Session must not specified for snowpark dataset."
|
346
337
|
),
|
347
338
|
)
|
348
|
-
|
349
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
350
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
339
|
+
|
351
340
|
|
352
341
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
353
342
|
@telemetry.send_api_usage_telemetry(
|
@@ -395,7 +384,8 @@ class MeanShift(BaseTransformer):
|
|
395
384
|
|
396
385
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
397
386
|
|
398
|
-
self.
|
387
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
388
|
+
self._deps = self._get_dependencies()
|
399
389
|
assert isinstance(
|
400
390
|
dataset._session, Session
|
401
391
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -478,10 +468,8 @@ class MeanShift(BaseTransformer):
|
|
478
468
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
479
469
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
480
470
|
|
481
|
-
self.
|
482
|
-
|
483
|
-
inference_method=inference_method,
|
484
|
-
)
|
471
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
472
|
+
self._deps = self._get_dependencies()
|
485
473
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
486
474
|
|
487
475
|
transform_kwargs = dict(
|
@@ -550,16 +538,40 @@ class MeanShift(BaseTransformer):
|
|
550
538
|
self._is_fitted = True
|
551
539
|
return output_result
|
552
540
|
|
541
|
+
|
542
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
543
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
544
|
+
""" Method not supported for this class.
|
545
|
+
|
553
546
|
|
554
|
-
|
555
|
-
|
556
|
-
|
547
|
+
Raises:
|
548
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
549
|
+
|
550
|
+
Args:
|
551
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
552
|
+
Snowpark or Pandas DataFrame.
|
553
|
+
output_cols_prefix: Prefix for the response columns
|
557
554
|
Returns:
|
558
555
|
Transformed dataset.
|
559
556
|
"""
|
560
|
-
self.
|
561
|
-
|
562
|
-
|
557
|
+
self._infer_input_output_cols(dataset)
|
558
|
+
super()._check_dataset_type(dataset)
|
559
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
560
|
+
estimator=self._sklearn_object,
|
561
|
+
dataset=dataset,
|
562
|
+
input_cols=self.input_cols,
|
563
|
+
label_cols=self.label_cols,
|
564
|
+
sample_weight_col=self.sample_weight_col,
|
565
|
+
autogenerated=self._autogenerated,
|
566
|
+
subproject=_SUBPROJECT,
|
567
|
+
)
|
568
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
569
|
+
drop_input_cols=self._drop_input_cols,
|
570
|
+
expected_output_cols_list=self.output_cols,
|
571
|
+
)
|
572
|
+
self._sklearn_object = fitted_estimator
|
573
|
+
self._is_fitted = True
|
574
|
+
return output_result
|
563
575
|
|
564
576
|
|
565
577
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -650,10 +662,8 @@ class MeanShift(BaseTransformer):
|
|
650
662
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
651
663
|
|
652
664
|
if isinstance(dataset, DataFrame):
|
653
|
-
self.
|
654
|
-
|
655
|
-
inference_method=inference_method,
|
656
|
-
)
|
665
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
666
|
+
self._deps = self._get_dependencies()
|
657
667
|
assert isinstance(
|
658
668
|
dataset._session, Session
|
659
669
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -718,10 +728,8 @@ class MeanShift(BaseTransformer):
|
|
718
728
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
719
729
|
|
720
730
|
if isinstance(dataset, DataFrame):
|
721
|
-
self.
|
722
|
-
|
723
|
-
inference_method=inference_method,
|
724
|
-
)
|
731
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
732
|
+
self._deps = self._get_dependencies()
|
725
733
|
assert isinstance(
|
726
734
|
dataset._session, Session
|
727
735
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -783,10 +791,8 @@ class MeanShift(BaseTransformer):
|
|
783
791
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
784
792
|
|
785
793
|
if isinstance(dataset, DataFrame):
|
786
|
-
self.
|
787
|
-
|
788
|
-
inference_method=inference_method,
|
789
|
-
)
|
794
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
795
|
+
self._deps = self._get_dependencies()
|
790
796
|
assert isinstance(
|
791
797
|
dataset._session, Session
|
792
798
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -852,10 +858,8 @@ class MeanShift(BaseTransformer):
|
|
852
858
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
853
859
|
|
854
860
|
if isinstance(dataset, DataFrame):
|
855
|
-
self.
|
856
|
-
|
857
|
-
inference_method=inference_method,
|
858
|
-
)
|
861
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
862
|
+
self._deps = self._get_dependencies()
|
859
863
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
860
864
|
transform_kwargs = dict(
|
861
865
|
session=dataset._session,
|
@@ -917,17 +921,15 @@ class MeanShift(BaseTransformer):
|
|
917
921
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
918
922
|
|
919
923
|
if isinstance(dataset, DataFrame):
|
920
|
-
self.
|
921
|
-
|
922
|
-
inference_method="score",
|
923
|
-
)
|
924
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
925
|
+
self._deps = self._get_dependencies()
|
924
926
|
selected_cols = self._get_active_columns()
|
925
927
|
if len(selected_cols) > 0:
|
926
928
|
dataset = dataset.select(selected_cols)
|
927
929
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
928
930
|
transform_kwargs = dict(
|
929
931
|
session=dataset._session,
|
930
|
-
dependencies=
|
932
|
+
dependencies=self._deps,
|
931
933
|
score_sproc_imports=['sklearn'],
|
932
934
|
)
|
933
935
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -992,11 +994,8 @@ class MeanShift(BaseTransformer):
|
|
992
994
|
|
993
995
|
if isinstance(dataset, DataFrame):
|
994
996
|
|
995
|
-
self.
|
996
|
-
|
997
|
-
inference_method=inference_method,
|
998
|
-
|
999
|
-
)
|
997
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
998
|
+
self._deps = self._get_dependencies()
|
1000
999
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1001
1000
|
transform_kwargs = dict(
|
1002
1001
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklear
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class MiniBatchKMeans(BaseTransformer):
|
70
64
|
r"""Mini-Batch K-Means clustering
|
71
65
|
For more details on this class, see [sklearn.cluster.MiniBatchKMeans]
|
@@ -364,20 +358,17 @@ class MiniBatchKMeans(BaseTransformer):
|
|
364
358
|
self,
|
365
359
|
dataset: DataFrame,
|
366
360
|
inference_method: str,
|
367
|
-
) ->
|
368
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
369
|
-
return the available package that exists in the snowflake anaconda channel
|
361
|
+
) -> None:
|
362
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
370
363
|
|
371
364
|
Args:
|
372
365
|
dataset: snowpark dataframe
|
373
366
|
inference_method: the inference method such as predict, score...
|
374
|
-
|
367
|
+
|
375
368
|
Raises:
|
376
369
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
377
370
|
SnowflakeMLException: If the session is None, raise error
|
378
371
|
|
379
|
-
Returns:
|
380
|
-
A list of available package that exists in the snowflake anaconda channel
|
381
372
|
"""
|
382
373
|
if not self._is_fitted:
|
383
374
|
raise exceptions.SnowflakeMLException(
|
@@ -395,9 +386,7 @@ class MiniBatchKMeans(BaseTransformer):
|
|
395
386
|
"Session must not specified for snowpark dataset."
|
396
387
|
),
|
397
388
|
)
|
398
|
-
|
399
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
400
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
389
|
+
|
401
390
|
|
402
391
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
403
392
|
@telemetry.send_api_usage_telemetry(
|
@@ -445,7 +434,8 @@ class MiniBatchKMeans(BaseTransformer):
|
|
445
434
|
|
446
435
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
447
436
|
|
448
|
-
self.
|
437
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
438
|
+
self._deps = self._get_dependencies()
|
449
439
|
assert isinstance(
|
450
440
|
dataset._session, Session
|
451
441
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -530,10 +520,8 @@ class MiniBatchKMeans(BaseTransformer):
|
|
530
520
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
531
521
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
532
522
|
|
533
|
-
self.
|
534
|
-
|
535
|
-
inference_method=inference_method,
|
536
|
-
)
|
523
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
524
|
+
self._deps = self._get_dependencies()
|
537
525
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
538
526
|
|
539
527
|
transform_kwargs = dict(
|
@@ -602,16 +590,42 @@ class MiniBatchKMeans(BaseTransformer):
|
|
602
590
|
self._is_fitted = True
|
603
591
|
return output_result
|
604
592
|
|
593
|
+
|
594
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
595
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
596
|
+
""" Compute clustering and transform X to cluster-distance space
|
597
|
+
For more details on this function, see [sklearn.cluster.MiniBatchKMeans.fit_transform]
|
598
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.cluster.MiniBatchKMeans.html#sklearn.cluster.MiniBatchKMeans.fit_transform)
|
599
|
+
|
605
600
|
|
606
|
-
|
607
|
-
|
608
|
-
|
601
|
+
Raises:
|
602
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
603
|
+
|
604
|
+
Args:
|
605
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
606
|
+
Snowpark or Pandas DataFrame.
|
607
|
+
output_cols_prefix: Prefix for the response columns
|
609
608
|
Returns:
|
610
609
|
Transformed dataset.
|
611
610
|
"""
|
612
|
-
self.
|
613
|
-
|
614
|
-
|
611
|
+
self._infer_input_output_cols(dataset)
|
612
|
+
super()._check_dataset_type(dataset)
|
613
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
614
|
+
estimator=self._sklearn_object,
|
615
|
+
dataset=dataset,
|
616
|
+
input_cols=self.input_cols,
|
617
|
+
label_cols=self.label_cols,
|
618
|
+
sample_weight_col=self.sample_weight_col,
|
619
|
+
autogenerated=self._autogenerated,
|
620
|
+
subproject=_SUBPROJECT,
|
621
|
+
)
|
622
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
623
|
+
drop_input_cols=self._drop_input_cols,
|
624
|
+
expected_output_cols_list=self.output_cols,
|
625
|
+
)
|
626
|
+
self._sklearn_object = fitted_estimator
|
627
|
+
self._is_fitted = True
|
628
|
+
return output_result
|
615
629
|
|
616
630
|
|
617
631
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -702,10 +716,8 @@ class MiniBatchKMeans(BaseTransformer):
|
|
702
716
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
703
717
|
|
704
718
|
if isinstance(dataset, DataFrame):
|
705
|
-
self.
|
706
|
-
|
707
|
-
inference_method=inference_method,
|
708
|
-
)
|
719
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
720
|
+
self._deps = self._get_dependencies()
|
709
721
|
assert isinstance(
|
710
722
|
dataset._session, Session
|
711
723
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -770,10 +782,8 @@ class MiniBatchKMeans(BaseTransformer):
|
|
770
782
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
771
783
|
|
772
784
|
if isinstance(dataset, DataFrame):
|
773
|
-
self.
|
774
|
-
|
775
|
-
inference_method=inference_method,
|
776
|
-
)
|
785
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
786
|
+
self._deps = self._get_dependencies()
|
777
787
|
assert isinstance(
|
778
788
|
dataset._session, Session
|
779
789
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -835,10 +845,8 @@ class MiniBatchKMeans(BaseTransformer):
|
|
835
845
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
836
846
|
|
837
847
|
if isinstance(dataset, DataFrame):
|
838
|
-
self.
|
839
|
-
|
840
|
-
inference_method=inference_method,
|
841
|
-
)
|
848
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
849
|
+
self._deps = self._get_dependencies()
|
842
850
|
assert isinstance(
|
843
851
|
dataset._session, Session
|
844
852
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -904,10 +912,8 @@ class MiniBatchKMeans(BaseTransformer):
|
|
904
912
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
905
913
|
|
906
914
|
if isinstance(dataset, DataFrame):
|
907
|
-
self.
|
908
|
-
|
909
|
-
inference_method=inference_method,
|
910
|
-
)
|
915
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
916
|
+
self._deps = self._get_dependencies()
|
911
917
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
912
918
|
transform_kwargs = dict(
|
913
919
|
session=dataset._session,
|
@@ -971,17 +977,15 @@ class MiniBatchKMeans(BaseTransformer):
|
|
971
977
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
972
978
|
|
973
979
|
if isinstance(dataset, DataFrame):
|
974
|
-
self.
|
975
|
-
|
976
|
-
inference_method="score",
|
977
|
-
)
|
980
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
981
|
+
self._deps = self._get_dependencies()
|
978
982
|
selected_cols = self._get_active_columns()
|
979
983
|
if len(selected_cols) > 0:
|
980
984
|
dataset = dataset.select(selected_cols)
|
981
985
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
982
986
|
transform_kwargs = dict(
|
983
987
|
session=dataset._session,
|
984
|
-
dependencies=
|
988
|
+
dependencies=self._deps,
|
985
989
|
score_sproc_imports=['sklearn'],
|
986
990
|
)
|
987
991
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1046,11 +1050,8 @@ class MiniBatchKMeans(BaseTransformer):
|
|
1046
1050
|
|
1047
1051
|
if isinstance(dataset, DataFrame):
|
1048
1052
|
|
1049
|
-
self.
|
1050
|
-
|
1051
|
-
inference_method=inference_method,
|
1052
|
-
|
1053
|
-
)
|
1053
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1054
|
+
self._deps = self._get_dependencies()
|
1054
1055
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1055
1056
|
transform_kwargs = dict(
|
1056
1057
|
session = dataset._session,
|