snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
- snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
- snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
- snowflake/ml/_internal/env_utils.py +31 -52
- snowflake/ml/_internal/file_utils.py +17 -0
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/_internal/utils/query_result_checker.py +8 -5
- snowflake/ml/_internal/utils/snowflake_env.py +95 -0
- snowflake/ml/fileset/parquet_parser.py +31 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/model_impl.py +172 -13
- snowflake/ml/model/_client/model/model_version_impl.py +96 -52
- snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
- snowflake/ml/model/_client/ops/model_ops.py +155 -9
- snowflake/ml/model/_client/sql/model.py +55 -10
- snowflake/ml/model/_client/sql/model_version.py +72 -61
- snowflake/ml/model/_client/sql/stage.py +10 -4
- snowflake/ml/model/_client/sql/tag.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/_signatures/core.py +20 -17
- snowflake/ml/model/custom_model.py +30 -27
- snowflake/ml/model/model_signature.py +16 -17
- snowflake/ml/model/type_hints.py +3 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
- snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
- snowflake/ml/modeling/_internal/model_specifications.py +3 -10
- snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
- snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
- snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
- snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
- snowflake/ml/modeling/cluster/birch.py +51 -16
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
- snowflake/ml/modeling/cluster/dbscan.py +51 -16
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
- snowflake/ml/modeling/cluster/k_means.py +51 -16
- snowflake/ml/modeling/cluster/mean_shift.py +51 -16
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
- snowflake/ml/modeling/cluster/optics.py +51 -16
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
- snowflake/ml/modeling/compose/column_transformer.py +51 -16
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
- snowflake/ml/modeling/covariance/oas.py +51 -16
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/pca.py +51 -16
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
- snowflake/ml/modeling/impute/knn_imputer.py +51 -16
- snowflake/ml/modeling/impute/missing_indicator.py +51 -16
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/lars.py +51 -16
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/perceptron.py +51 -16
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ridge.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
- snowflake/ml/modeling/manifold/isomap.py +51 -16
- snowflake/ml/modeling/manifold/mds.py +51 -16
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
- snowflake/ml/modeling/manifold/tsne.py +51 -16
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
- snowflake/ml/modeling/svm/linear_svc.py +51 -16
- snowflake/ml/modeling/svm/linear_svr.py +51 -16
- snowflake/ml/modeling/svm/nu_svc.py +51 -16
- snowflake/ml/modeling/svm/nu_svr.py +51 -16
- snowflake/ml/modeling/svm/svc.py +51 -16
- snowflake/ml/modeling/svm/svr.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
- snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
- snowflake/ml/registry/__init__.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +163 -0
- snowflake/ml/registry/model_registry.py +12 -0
- snowflake/ml/registry/registry.py +100 -90
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
- snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
- snowflake/ml/model/_client/model/model_method_info.py +0 -19
- snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
- /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
- /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
--- snowflake/ml/modeling/cluster/k_means.py (1.1.2)
+++ snowflake/ml/modeling/cluster/k_means.py (1.2.1)
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return True and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class KMeans(BaseTransformer):
     r"""K-Means clustering
     For more details on this class, see [sklearn.cluster.KMeans]
@@ -201,7 +213,9 @@ class KMeans(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -285,11 +299,6 @@ class KMeans(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -317,7 +326,9 @@ class KMeans(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -590,6 +601,22 @@ class KMeans(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -605,8 +632,8 @@ class KMeans(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Compute cluster centers and predict cluster index for each sample
         For more details on this function, see [sklearn.cluster.KMeans.fit_predict]
         (https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html#sklearn.cluster.KMeans.fit_predict)
@@ -621,13 +648,21 @@ class KMeans(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
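The `@available_if` hunks above gate `fit_predict`/`fit_transform` on what the wrapped estimator actually supports. A minimal sketch of that pattern using only scikit-learn's public `available_if` helper; the `Wrapper` class and its `_sklearn_object` attribute here are illustrative stand-ins, not the library's real base class:

```python
# Minimal sketch of the available_if gating pattern (sklearn >= 1.0).
from typing import Any, Callable

from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.utils.metaestimators import available_if


def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
    def check(self: "Wrapper") -> bool:
        # Expose fit_predict only when the wrapped estimator defines it.
        return callable(getattr(self._sklearn_object, "fit_predict", None))
    return check


class Wrapper:
    def __init__(self, sklearn_object: Any) -> None:
        self._sklearn_object = sklearn_object

    @available_if(_is_fit_predict_method_enabled())
    def fit_predict(self, X: Any) -> Any:
        self._sklearn_object.fit(X)
        return self._sklearn_object.labels_


print(hasattr(Wrapper(KMeans()), "fit_predict"))  # True: KMeans defines fit_predict
print(hasattr(Wrapper(PCA()), "fit_predict"))     # False: check fails, access raises AttributeError
```

When the check returns falsy, `available_if` raises `AttributeError` on attribute access, so `hasattr` and duck-typing callers see the method only where it is meaningful.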
--- snowflake/ml/modeling/cluster/mean_shift.py (1.1.2)
+++ snowflake/ml/modeling/cluster/mean_shift.py (1.2.1)
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return True and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class MeanShift(BaseTransformer):
     r"""Mean shift clustering using a flat kernel
     For more details on this class, see [sklearn.cluster.MeanShift]
@@ -179,7 +191,9 @@ class MeanShift(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -261,11 +275,6 @@ class MeanShift(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -293,7 +302,9 @@ class MeanShift(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -564,6 +575,22 @@ class MeanShift(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -579,8 +606,8 @@ class MeanShift(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Perform clustering on `X` and returns cluster labels
         For more details on this function, see [sklearn.cluster.MeanShift.fit_predict]
         (https://scikit-learn.org/stable/modules/generated/sklearn.cluster.MeanShift.html#sklearn.cluster.MeanShift.fit_predict)
@@ -595,13 +622,21 @@ class MeanShift(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
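The `expected_dtype` fallback added to `transform()` is easier to read as a standalone function. A sketch restating the same decision rule, where `convert` stands in for the library's `convert_sp_to_sf_type` and `input_types` for the Snowpark types inferred from the input columns:

```python
# Standalone restatement of the dtype-inference fallback shown in the hunks above;
# the function name and parameters here are illustrative, not the library's API.
from typing import Any, Callable, Sequence


def infer_expected_dtype(
    sk_obj: Any,
    output_cols: Sequence[str],
    input_types: Sequence[Any],
    convert: Callable[[Any], str],
) -> str:
    # Clustering transformers whose output width differs from n_clusters emit an ARRAY column.
    if hasattr(sk_obj, "n_clusters") and sk_obj.n_clusters != len(output_cols):
        return "ARRAY"
    # Same rule for decomposition transformers and n_components.
    if hasattr(sk_obj, "n_components") and sk_obj.n_components != len(output_cols):
        return "ARRAY"
    # Reuse the input type only when all input types agree and widths match.
    if input_types and all(t == input_types[0] for t in input_types) and len(input_types) == len(output_cols):
        return convert(input_types[0])
    return ""  # unresolved: caller falls back to variant handling
```

For example, a clusterer with eight clusters and a single output column yields "ARRAY"; one input and one output column of the same type forwards that type unchanged.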
--- snowflake/ml/modeling/cluster/mini_batch_k_means.py (1.1.2)
+++ snowflake/ml/modeling/cluster/mini_batch_k_means.py (1.2.1)
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return True and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class MiniBatchKMeans(BaseTransformer):
     r"""Mini-Batch K-Means clustering
     For more details on this class, see [sklearn.cluster.MiniBatchKMeans]
@@ -224,7 +236,9 @@ class MiniBatchKMeans(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -311,11 +325,6 @@ class MiniBatchKMeans(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -343,7 +352,9 @@ class MiniBatchKMeans(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -616,6 +627,22 @@ class MiniBatchKMeans(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -631,8 +658,8 @@ class MiniBatchKMeans(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Compute cluster centers and predict cluster index for each sample
         For more details on this function, see [sklearn.cluster.MiniBatchKMeans.fit_predict]
         (https://scikit-learn.org/stable/modules/generated/sklearn.cluster.MiniBatchKMeans.html#sklearn.cluster.MiniBatchKMeans.fit_predict)
@@ -647,13 +674,21 @@ class MiniBatchKMeans(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
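Every generated estimator now carries the two batching fields and forwards them when the trainer is built; for these sklearn cluster classes they stay at their defaults, and only the `snowflake/ml/modeling/xgboost` wrappers (note their larger +69 hunks in the file list, alongside the new `xgboost_external_memory_trainer.py`) are expected to flip them. A self-contained sketch of that plumbing, with the config type and builder function assumed purely for illustration:

```python
# Illustrative sketch only: TrainerConfig and build_trainer are stand-ins,
# not the library's actual ModelTrainerBuilder API.
from dataclasses import dataclass


@dataclass
class TrainerConfig:
    subproject: str
    use_external_memory_version: bool = False  # default set in every generated __init__
    batch_size: int = -1                       # -1 acts as a "no batching" sentinel


def build_trainer(cfg: TrainerConfig) -> str:
    # External-memory training only engages when both knobs are set.
    if cfg.use_external_memory_version and cfg.batch_size > 0:
        return f"external-memory trainer, batch_size={cfg.batch_size}"
    return "in-memory trainer"


print(build_trainer(TrainerConfig(subproject="Cluster")))  # in-memory trainer
print(build_trainer(TrainerConfig(subproject="Xgboost", use_external_memory_version=True, batch_size=10_000)))
```

Threading the defaults through every wrapper keeps the trainer-builder signature uniform, so the batched code path is opt-in per estimator rather than a special case.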
--- snowflake/ml/modeling/cluster/optics.py (1.1.2)
+++ snowflake/ml/modeling/cluster/optics.py (1.2.1)
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return True and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class OPTICS(BaseTransformer):
     r"""Estimate clustering structure from vector array
     For more details on this class, see [sklearn.cluster.OPTICS]
@@ -242,7 +254,9 @@ class OPTICS(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -331,11 +345,6 @@ class OPTICS(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -363,7 +372,9 @@ class OPTICS(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -632,6 +643,22 @@ class OPTICS(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -647,8 +674,8 @@ class OPTICS(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Perform clustering on `X` and returns cluster labels
         For more details on this function, see [sklearn.cluster.OPTICS.fit_predict]
         (https://scikit-learn.org/stable/modules/generated/sklearn.cluster.OPTICS.html#sklearn.cluster.OPTICS.fit_predict)
@@ -663,13 +690,21 @@ class OPTICS(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
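The new `fit_predict` body makes the behavior concrete: fit the wrapped estimator, then surface its `labels_`. A hedged usage sketch against the public wrapper, assuming a local pandas input and the package's usual `input_cols` constructor argument (the data values here are made up):

```python
# Usage sketch; assumes snowflake-ml-python >= 1.2.1 with a pandas DataFrame input.
import pandas as pd

from snowflake.ml.modeling.cluster import OPTICS  # path from the file list above

df = pd.DataFrame({
    "X1": [0.0, 0.1, 0.2, 5.0, 5.1, 5.2],
    "X2": [0.0, 0.1, 0.2, 5.0, 5.1, 5.2],
})
clu = OPTICS(input_cols=["X1", "X2"], min_samples=2)

labels = clu.fit_predict(df)  # fits, then returns the fitted estimator's labels_
print(labels)                 # one cluster index per input row
```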
--- snowflake/ml/modeling/cluster/spectral_biclustering.py (1.1.2)
+++ snowflake/ml/modeling/cluster/spectral_biclustering.py (1.2.1)
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class SpectralBiclustering(BaseTransformer):
     r"""Spectral biclustering (Kluger, 2003)
     For more details on this class, see [sklearn.cluster.SpectralBiclustering]
@@ -184,7 +196,9 @@ class SpectralBiclustering(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -269,11 +283,6 @@ class SpectralBiclustering(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -301,7 +310,9 @@ class SpectralBiclustering(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -570,6 +581,22 @@ class SpectralBiclustering(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -585,8 +612,8 @@ class SpectralBiclustering(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -599,13 +626,21 @@ class SpectralBiclustering(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
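One detail worth spelling out: in the first four files the generated `fit_predict` check reads `True and callable(...)`, while SpectralBiclustering's reads `False and callable(...)`, and every `fit_transform` check starts with `False`. Because `and` short-circuits, a leading `False` makes the check a constant, so `@available_if` permanently hides those methods no matter what the wrapped estimator supports. A tiny demonstration of the short-circuit the code generator relies on:

```python
# "False and X" never evaluates X and always yields False, so a gate
# generated with a leading False is permanently closed.
calls: list[bool] = []


def probe() -> bool:
    calls.append(True)
    return True


assert (False and probe()) is False
assert (True and probe()) is True
assert calls == [True]  # probe ran only for the True-led expression
```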