snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.manifold".replace("sklea
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return True and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class TSNE(BaseTransformer):
|
70
64
|
r"""T-distributed Stochastic Neighbor Embedding
|
71
65
|
For more details on this class, see [sklearn.manifold.TSNE]
|
@@ -383,20 +377,17 @@ class TSNE(BaseTransformer):
|
|
383
377
|
self,
|
384
378
|
dataset: DataFrame,
|
385
379
|
inference_method: str,
|
386
|
-
) ->
|
387
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
388
|
-
return the available package that exists in the snowflake anaconda channel
|
380
|
+
) -> None:
|
381
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
389
382
|
|
390
383
|
Args:
|
391
384
|
dataset: snowpark dataframe
|
392
385
|
inference_method: the inference method such as predict, score...
|
393
|
-
|
386
|
+
|
394
387
|
Raises:
|
395
388
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
396
389
|
SnowflakeMLException: If the session is None, raise error
|
397
390
|
|
398
|
-
Returns:
|
399
|
-
A list of available package that exists in the snowflake anaconda channel
|
400
391
|
"""
|
401
392
|
if not self._is_fitted:
|
402
393
|
raise exceptions.SnowflakeMLException(
|
@@ -414,9 +405,7 @@ class TSNE(BaseTransformer):
|
|
414
405
|
"Session must not specified for snowpark dataset."
|
415
406
|
),
|
416
407
|
)
|
417
|
-
|
418
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
419
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
408
|
+
|
420
409
|
|
421
410
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
422
411
|
@telemetry.send_api_usage_telemetry(
|
@@ -462,7 +451,8 @@ class TSNE(BaseTransformer):
|
|
462
451
|
|
463
452
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
464
453
|
|
465
|
-
self.
|
454
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
455
|
+
self._deps = self._get_dependencies()
|
466
456
|
assert isinstance(
|
467
457
|
dataset._session, Session
|
468
458
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -545,10 +535,8 @@ class TSNE(BaseTransformer):
|
|
545
535
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
546
536
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
547
537
|
|
548
|
-
self.
|
549
|
-
|
550
|
-
inference_method=inference_method,
|
551
|
-
)
|
538
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
539
|
+
self._deps = self._get_dependencies()
|
552
540
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
553
541
|
|
554
542
|
transform_kwargs = dict(
|
@@ -615,16 +603,42 @@ class TSNE(BaseTransformer):
|
|
615
603
|
self._is_fitted = True
|
616
604
|
return output_result
|
617
605
|
|
606
|
+
|
607
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
608
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
609
|
+
""" Fit X into an embedded space and return that transformed output
|
610
|
+
For more details on this function, see [sklearn.manifold.TSNE.fit_transform]
|
611
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html#sklearn.manifold.TSNE.fit_transform)
|
612
|
+
|
618
613
|
|
619
|
-
|
620
|
-
|
621
|
-
|
614
|
+
Raises:
|
615
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
616
|
+
|
617
|
+
Args:
|
618
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
619
|
+
Snowpark or Pandas DataFrame.
|
620
|
+
output_cols_prefix: Prefix for the response columns
|
622
621
|
Returns:
|
623
622
|
Transformed dataset.
|
624
623
|
"""
|
625
|
-
self.
|
626
|
-
|
627
|
-
|
624
|
+
self._infer_input_output_cols(dataset)
|
625
|
+
super()._check_dataset_type(dataset)
|
626
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
627
|
+
estimator=self._sklearn_object,
|
628
|
+
dataset=dataset,
|
629
|
+
input_cols=self.input_cols,
|
630
|
+
label_cols=self.label_cols,
|
631
|
+
sample_weight_col=self.sample_weight_col,
|
632
|
+
autogenerated=self._autogenerated,
|
633
|
+
subproject=_SUBPROJECT,
|
634
|
+
)
|
635
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
636
|
+
drop_input_cols=self._drop_input_cols,
|
637
|
+
expected_output_cols_list=self.output_cols,
|
638
|
+
)
|
639
|
+
self._sklearn_object = fitted_estimator
|
640
|
+
self._is_fitted = True
|
641
|
+
return output_result
|
628
642
|
|
629
643
|
|
630
644
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -715,10 +729,8 @@ class TSNE(BaseTransformer):
|
|
715
729
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
716
730
|
|
717
731
|
if isinstance(dataset, DataFrame):
|
718
|
-
self.
|
719
|
-
|
720
|
-
inference_method=inference_method,
|
721
|
-
)
|
732
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
733
|
+
self._deps = self._get_dependencies()
|
722
734
|
assert isinstance(
|
723
735
|
dataset._session, Session
|
724
736
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -783,10 +795,8 @@ class TSNE(BaseTransformer):
|
|
783
795
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
784
796
|
|
785
797
|
if isinstance(dataset, DataFrame):
|
786
|
-
self.
|
787
|
-
|
788
|
-
inference_method=inference_method,
|
789
|
-
)
|
798
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
799
|
+
self._deps = self._get_dependencies()
|
790
800
|
assert isinstance(
|
791
801
|
dataset._session, Session
|
792
802
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -848,10 +858,8 @@ class TSNE(BaseTransformer):
|
|
848
858
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
849
859
|
|
850
860
|
if isinstance(dataset, DataFrame):
|
851
|
-
self.
|
852
|
-
|
853
|
-
inference_method=inference_method,
|
854
|
-
)
|
861
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
862
|
+
self._deps = self._get_dependencies()
|
855
863
|
assert isinstance(
|
856
864
|
dataset._session, Session
|
857
865
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -917,10 +925,8 @@ class TSNE(BaseTransformer):
|
|
917
925
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
918
926
|
|
919
927
|
if isinstance(dataset, DataFrame):
|
920
|
-
self.
|
921
|
-
|
922
|
-
inference_method=inference_method,
|
923
|
-
)
|
928
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
929
|
+
self._deps = self._get_dependencies()
|
924
930
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
925
931
|
transform_kwargs = dict(
|
926
932
|
session=dataset._session,
|
@@ -982,17 +988,15 @@ class TSNE(BaseTransformer):
|
|
982
988
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
983
989
|
|
984
990
|
if isinstance(dataset, DataFrame):
|
985
|
-
self.
|
986
|
-
|
987
|
-
inference_method="score",
|
988
|
-
)
|
991
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
992
|
+
self._deps = self._get_dependencies()
|
989
993
|
selected_cols = self._get_active_columns()
|
990
994
|
if len(selected_cols) > 0:
|
991
995
|
dataset = dataset.select(selected_cols)
|
992
996
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
993
997
|
transform_kwargs = dict(
|
994
998
|
session=dataset._session,
|
995
|
-
dependencies=
|
999
|
+
dependencies=self._deps,
|
996
1000
|
score_sproc_imports=['sklearn'],
|
997
1001
|
)
|
998
1002
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1057,11 +1061,8 @@ class TSNE(BaseTransformer):
|
|
1057
1061
|
|
1058
1062
|
if isinstance(dataset, DataFrame):
|
1059
1063
|
|
1060
|
-
self.
|
1061
|
-
|
1062
|
-
inference_method=inference_method,
|
1063
|
-
|
1064
|
-
)
|
1064
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1065
|
+
self._deps = self._get_dependencies()
|
1065
1066
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1066
1067
|
transform_kwargs = dict(
|
1067
1068
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.mixture".replace("sklear
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class BayesianGaussianMixture(BaseTransformer):
|
70
64
|
r"""Variational Bayesian estimation of a Gaussian mixture
|
71
65
|
For more details on this class, see [sklearn.mixture.BayesianGaussianMixture]
|
@@ -386,20 +380,17 @@ class BayesianGaussianMixture(BaseTransformer):
|
|
386
380
|
self,
|
387
381
|
dataset: DataFrame,
|
388
382
|
inference_method: str,
|
389
|
-
) ->
|
390
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
391
|
-
return the available package that exists in the snowflake anaconda channel
|
383
|
+
) -> None:
|
384
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
392
385
|
|
393
386
|
Args:
|
394
387
|
dataset: snowpark dataframe
|
395
388
|
inference_method: the inference method such as predict, score...
|
396
|
-
|
389
|
+
|
397
390
|
Raises:
|
398
391
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
399
392
|
SnowflakeMLException: If the session is None, raise error
|
400
393
|
|
401
|
-
Returns:
|
402
|
-
A list of available package that exists in the snowflake anaconda channel
|
403
394
|
"""
|
404
395
|
if not self._is_fitted:
|
405
396
|
raise exceptions.SnowflakeMLException(
|
@@ -417,9 +408,7 @@ class BayesianGaussianMixture(BaseTransformer):
|
|
417
408
|
"Session must not specified for snowpark dataset."
|
418
409
|
),
|
419
410
|
)
|
420
|
-
|
421
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
422
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
411
|
+
|
423
412
|
|
424
413
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
425
414
|
@telemetry.send_api_usage_telemetry(
|
@@ -467,7 +456,8 @@ class BayesianGaussianMixture(BaseTransformer):
|
|
467
456
|
|
468
457
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
469
458
|
|
470
|
-
self.
|
459
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
460
|
+
self._deps = self._get_dependencies()
|
471
461
|
assert isinstance(
|
472
462
|
dataset._session, Session
|
473
463
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -550,10 +540,8 @@ class BayesianGaussianMixture(BaseTransformer):
|
|
550
540
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
551
541
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
552
542
|
|
553
|
-
self.
|
554
|
-
|
555
|
-
inference_method=inference_method,
|
556
|
-
)
|
543
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
544
|
+
self._deps = self._get_dependencies()
|
557
545
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
558
546
|
|
559
547
|
transform_kwargs = dict(
|
@@ -622,16 +610,40 @@ class BayesianGaussianMixture(BaseTransformer):
|
|
622
610
|
self._is_fitted = True
|
623
611
|
return output_result
|
624
612
|
|
613
|
+
|
614
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
615
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
616
|
+
""" Method not supported for this class.
|
617
|
+
|
625
618
|
|
626
|
-
|
627
|
-
|
628
|
-
|
619
|
+
Raises:
|
620
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
621
|
+
|
622
|
+
Args:
|
623
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
624
|
+
Snowpark or Pandas DataFrame.
|
625
|
+
output_cols_prefix: Prefix for the response columns
|
629
626
|
Returns:
|
630
627
|
Transformed dataset.
|
631
628
|
"""
|
632
|
-
self.
|
633
|
-
|
634
|
-
|
629
|
+
self._infer_input_output_cols(dataset)
|
630
|
+
super()._check_dataset_type(dataset)
|
631
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
632
|
+
estimator=self._sklearn_object,
|
633
|
+
dataset=dataset,
|
634
|
+
input_cols=self.input_cols,
|
635
|
+
label_cols=self.label_cols,
|
636
|
+
sample_weight_col=self.sample_weight_col,
|
637
|
+
autogenerated=self._autogenerated,
|
638
|
+
subproject=_SUBPROJECT,
|
639
|
+
)
|
640
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
641
|
+
drop_input_cols=self._drop_input_cols,
|
642
|
+
expected_output_cols_list=self.output_cols,
|
643
|
+
)
|
644
|
+
self._sklearn_object = fitted_estimator
|
645
|
+
self._is_fitted = True
|
646
|
+
return output_result
|
635
647
|
|
636
648
|
|
637
649
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -724,10 +736,8 @@ class BayesianGaussianMixture(BaseTransformer):
|
|
724
736
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
725
737
|
|
726
738
|
if isinstance(dataset, DataFrame):
|
727
|
-
self.
|
728
|
-
|
729
|
-
inference_method=inference_method,
|
730
|
-
)
|
739
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
740
|
+
self._deps = self._get_dependencies()
|
731
741
|
assert isinstance(
|
732
742
|
dataset._session, Session
|
733
743
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -794,10 +804,8 @@ class BayesianGaussianMixture(BaseTransformer):
|
|
794
804
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
795
805
|
|
796
806
|
if isinstance(dataset, DataFrame):
|
797
|
-
self.
|
798
|
-
|
799
|
-
inference_method=inference_method,
|
800
|
-
)
|
807
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
808
|
+
self._deps = self._get_dependencies()
|
801
809
|
assert isinstance(
|
802
810
|
dataset._session, Session
|
803
811
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -859,10 +867,8 @@ class BayesianGaussianMixture(BaseTransformer):
|
|
859
867
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
860
868
|
|
861
869
|
if isinstance(dataset, DataFrame):
|
862
|
-
self.
|
863
|
-
|
864
|
-
inference_method=inference_method,
|
865
|
-
)
|
870
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
871
|
+
self._deps = self._get_dependencies()
|
866
872
|
assert isinstance(
|
867
873
|
dataset._session, Session
|
868
874
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -930,10 +936,8 @@ class BayesianGaussianMixture(BaseTransformer):
|
|
930
936
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
931
937
|
|
932
938
|
if isinstance(dataset, DataFrame):
|
933
|
-
self.
|
934
|
-
|
935
|
-
inference_method=inference_method,
|
936
|
-
)
|
939
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
940
|
+
self._deps = self._get_dependencies()
|
937
941
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
938
942
|
transform_kwargs = dict(
|
939
943
|
session=dataset._session,
|
@@ -997,17 +1001,15 @@ class BayesianGaussianMixture(BaseTransformer):
|
|
997
1001
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
998
1002
|
|
999
1003
|
if isinstance(dataset, DataFrame):
|
1000
|
-
self.
|
1001
|
-
|
1002
|
-
inference_method="score",
|
1003
|
-
)
|
1004
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
1005
|
+
self._deps = self._get_dependencies()
|
1004
1006
|
selected_cols = self._get_active_columns()
|
1005
1007
|
if len(selected_cols) > 0:
|
1006
1008
|
dataset = dataset.select(selected_cols)
|
1007
1009
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
1008
1010
|
transform_kwargs = dict(
|
1009
1011
|
session=dataset._session,
|
1010
|
-
dependencies=
|
1012
|
+
dependencies=self._deps,
|
1011
1013
|
score_sproc_imports=['sklearn'],
|
1012
1014
|
)
|
1013
1015
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1072,11 +1074,8 @@ class BayesianGaussianMixture(BaseTransformer):
|
|
1072
1074
|
|
1073
1075
|
if isinstance(dataset, DataFrame):
|
1074
1076
|
|
1075
|
-
self.
|
1076
|
-
|
1077
|
-
inference_method=inference_method,
|
1078
|
-
|
1079
|
-
)
|
1077
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1078
|
+
self._deps = self._get_dependencies()
|
1080
1079
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1081
1080
|
transform_kwargs = dict(
|
1082
1081
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.mixture".replace("sklear
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class GaussianMixture(BaseTransformer):
|
70
64
|
r"""Gaussian Mixture
|
71
65
|
For more details on this class, see [sklearn.mixture.GaussianMixture]
|
@@ -359,20 +353,17 @@ class GaussianMixture(BaseTransformer):
|
|
359
353
|
self,
|
360
354
|
dataset: DataFrame,
|
361
355
|
inference_method: str,
|
362
|
-
) ->
|
363
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
364
|
-
return the available package that exists in the snowflake anaconda channel
|
356
|
+
) -> None:
|
357
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
365
358
|
|
366
359
|
Args:
|
367
360
|
dataset: snowpark dataframe
|
368
361
|
inference_method: the inference method such as predict, score...
|
369
|
-
|
362
|
+
|
370
363
|
Raises:
|
371
364
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
372
365
|
SnowflakeMLException: If the session is None, raise error
|
373
366
|
|
374
|
-
Returns:
|
375
|
-
A list of available package that exists in the snowflake anaconda channel
|
376
367
|
"""
|
377
368
|
if not self._is_fitted:
|
378
369
|
raise exceptions.SnowflakeMLException(
|
@@ -390,9 +381,7 @@ class GaussianMixture(BaseTransformer):
|
|
390
381
|
"Session must not specified for snowpark dataset."
|
391
382
|
),
|
392
383
|
)
|
393
|
-
|
394
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
395
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
384
|
+
|
396
385
|
|
397
386
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
398
387
|
@telemetry.send_api_usage_telemetry(
|
@@ -440,7 +429,8 @@ class GaussianMixture(BaseTransformer):
|
|
440
429
|
|
441
430
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
442
431
|
|
443
|
-
self.
|
432
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
433
|
+
self._deps = self._get_dependencies()
|
444
434
|
assert isinstance(
|
445
435
|
dataset._session, Session
|
446
436
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -523,10 +513,8 @@ class GaussianMixture(BaseTransformer):
|
|
523
513
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
524
514
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
525
515
|
|
526
|
-
self.
|
527
|
-
|
528
|
-
inference_method=inference_method,
|
529
|
-
)
|
516
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
517
|
+
self._deps = self._get_dependencies()
|
530
518
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
531
519
|
|
532
520
|
transform_kwargs = dict(
|
@@ -595,16 +583,40 @@ class GaussianMixture(BaseTransformer):
|
|
595
583
|
self._is_fitted = True
|
596
584
|
return output_result
|
597
585
|
|
586
|
+
|
587
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
588
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
589
|
+
""" Method not supported for this class.
|
590
|
+
|
598
591
|
|
599
|
-
|
600
|
-
|
601
|
-
|
592
|
+
Raises:
|
593
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
594
|
+
|
595
|
+
Args:
|
596
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
597
|
+
Snowpark or Pandas DataFrame.
|
598
|
+
output_cols_prefix: Prefix for the response columns
|
602
599
|
Returns:
|
603
600
|
Transformed dataset.
|
604
601
|
"""
|
605
|
-
self.
|
606
|
-
|
607
|
-
|
602
|
+
self._infer_input_output_cols(dataset)
|
603
|
+
super()._check_dataset_type(dataset)
|
604
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
605
|
+
estimator=self._sklearn_object,
|
606
|
+
dataset=dataset,
|
607
|
+
input_cols=self.input_cols,
|
608
|
+
label_cols=self.label_cols,
|
609
|
+
sample_weight_col=self.sample_weight_col,
|
610
|
+
autogenerated=self._autogenerated,
|
611
|
+
subproject=_SUBPROJECT,
|
612
|
+
)
|
613
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
614
|
+
drop_input_cols=self._drop_input_cols,
|
615
|
+
expected_output_cols_list=self.output_cols,
|
616
|
+
)
|
617
|
+
self._sklearn_object = fitted_estimator
|
618
|
+
self._is_fitted = True
|
619
|
+
return output_result
|
608
620
|
|
609
621
|
|
610
622
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -697,10 +709,8 @@ class GaussianMixture(BaseTransformer):
|
|
697
709
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
698
710
|
|
699
711
|
if isinstance(dataset, DataFrame):
|
700
|
-
self.
|
701
|
-
|
702
|
-
inference_method=inference_method,
|
703
|
-
)
|
712
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
713
|
+
self._deps = self._get_dependencies()
|
704
714
|
assert isinstance(
|
705
715
|
dataset._session, Session
|
706
716
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -767,10 +777,8 @@ class GaussianMixture(BaseTransformer):
|
|
767
777
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
768
778
|
|
769
779
|
if isinstance(dataset, DataFrame):
|
770
|
-
self.
|
771
|
-
|
772
|
-
inference_method=inference_method,
|
773
|
-
)
|
780
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
781
|
+
self._deps = self._get_dependencies()
|
774
782
|
assert isinstance(
|
775
783
|
dataset._session, Session
|
776
784
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -832,10 +840,8 @@ class GaussianMixture(BaseTransformer):
|
|
832
840
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
833
841
|
|
834
842
|
if isinstance(dataset, DataFrame):
|
835
|
-
self.
|
836
|
-
|
837
|
-
inference_method=inference_method,
|
838
|
-
)
|
843
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
844
|
+
self._deps = self._get_dependencies()
|
839
845
|
assert isinstance(
|
840
846
|
dataset._session, Session
|
841
847
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -903,10 +909,8 @@ class GaussianMixture(BaseTransformer):
|
|
903
909
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
904
910
|
|
905
911
|
if isinstance(dataset, DataFrame):
|
906
|
-
self.
|
907
|
-
|
908
|
-
inference_method=inference_method,
|
909
|
-
)
|
912
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
913
|
+
self._deps = self._get_dependencies()
|
910
914
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
911
915
|
transform_kwargs = dict(
|
912
916
|
session=dataset._session,
|
@@ -970,17 +974,15 @@ class GaussianMixture(BaseTransformer):
|
|
970
974
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
971
975
|
|
972
976
|
if isinstance(dataset, DataFrame):
|
973
|
-
self.
|
974
|
-
|
975
|
-
inference_method="score",
|
976
|
-
)
|
977
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
978
|
+
self._deps = self._get_dependencies()
|
977
979
|
selected_cols = self._get_active_columns()
|
978
980
|
if len(selected_cols) > 0:
|
979
981
|
dataset = dataset.select(selected_cols)
|
980
982
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
981
983
|
transform_kwargs = dict(
|
982
984
|
session=dataset._session,
|
983
|
-
dependencies=
|
985
|
+
dependencies=self._deps,
|
984
986
|
score_sproc_imports=['sklearn'],
|
985
987
|
)
|
986
988
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1045,11 +1047,8 @@ class GaussianMixture(BaseTransformer):
|
|
1045
1047
|
|
1046
1048
|
if isinstance(dataset, DataFrame):
|
1047
1049
|
|
1048
|
-
self.
|
1049
|
-
|
1050
|
-
inference_method=inference_method,
|
1051
|
-
|
1052
|
-
)
|
1050
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1051
|
+
self._deps = self._get_dependencies()
|
1053
1052
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1054
1053
|
transform_kwargs = dict(
|
1055
1054
|
session = dataset._session,
|