snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
- snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
- snowflake/ml/_internal/env_utils.py +31 -52
- snowflake/ml/_internal/file_utils.py +17 -0
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/_internal/utils/query_result_checker.py +8 -5
- snowflake/ml/_internal/utils/snowflake_env.py +95 -0
- snowflake/ml/fileset/parquet_parser.py +31 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/model_impl.py +172 -13
- snowflake/ml/model/_client/model/model_version_impl.py +96 -52
- snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
- snowflake/ml/model/_client/ops/model_ops.py +155 -9
- snowflake/ml/model/_client/sql/model.py +55 -10
- snowflake/ml/model/_client/sql/model_version.py +72 -61
- snowflake/ml/model/_client/sql/stage.py +10 -4
- snowflake/ml/model/_client/sql/tag.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/_signatures/core.py +20 -17
- snowflake/ml/model/custom_model.py +30 -27
- snowflake/ml/model/model_signature.py +16 -17
- snowflake/ml/model/type_hints.py +3 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
- snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
- snowflake/ml/modeling/_internal/model_specifications.py +3 -10
- snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
- snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
- snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
- snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
- snowflake/ml/modeling/cluster/birch.py +51 -16
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
- snowflake/ml/modeling/cluster/dbscan.py +51 -16
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
- snowflake/ml/modeling/cluster/k_means.py +51 -16
- snowflake/ml/modeling/cluster/mean_shift.py +51 -16
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
- snowflake/ml/modeling/cluster/optics.py +51 -16
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
- snowflake/ml/modeling/compose/column_transformer.py +51 -16
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
- snowflake/ml/modeling/covariance/oas.py +51 -16
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/pca.py +51 -16
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
- snowflake/ml/modeling/impute/knn_imputer.py +51 -16
- snowflake/ml/modeling/impute/missing_indicator.py +51 -16
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/lars.py +51 -16
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/perceptron.py +51 -16
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ridge.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
- snowflake/ml/modeling/manifold/isomap.py +51 -16
- snowflake/ml/modeling/manifold/mds.py +51 -16
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
- snowflake/ml/modeling/manifold/tsne.py +51 -16
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
- snowflake/ml/modeling/svm/linear_svc.py +51 -16
- snowflake/ml/modeling/svm/linear_svr.py +51 -16
- snowflake/ml/modeling/svm/nu_svc.py +51 -16
- snowflake/ml/modeling/svm/nu_svr.py +51 -16
- snowflake/ml/modeling/svm/svc.py +51 -16
- snowflake/ml/modeling/svm/svr.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
- snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
- snowflake/ml/registry/__init__.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +163 -0
- snowflake/ml/registry/model_registry.py +12 -0
- snowflake/ml/registry/registry.py +100 -90
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
- snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
- snowflake/ml/model/_client/model/model_method_info.py +0 -19
- snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
- /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
- /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
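The registry entries above (`registry.py` +100 -90, plus the new `_manager/model_manager.py`) are the user-facing center of this release. As a minimal usage sketch only, assuming the public `snowflake.ml.registry.Registry` API as it ships in the 1.2.x line, an existing Snowpark `session`, and a fitted estimator `clf` (the database/schema/model names are placeholders):

from snowflake.ml.registry import Registry

# Hypothetical names; Registry scopes itself to a database and schema.
reg = Registry(session=session, database_name="ML_DB", schema_name="PUBLIC")

# Log a fitted model as a new version, then run inference through it.
mv = reg.log_model(clf, model_name="MY_MODEL", version_name="V1")
preds = mv.run(test_df, function_name="predict")  # Snowpark DataFrame in, predictions out

The per-file diffs below show the autogenerated modeling estimators; the same six hunks repeat across all of the `+51 -16` files, so the five `snowflake/ml/modeling/cluster/*` diffs reproduced here stand in for the rest.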
snowflake/ml/modeling/cluster/agglomerative_clustering.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return True and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class AgglomerativeClustering(BaseTransformer):
     r"""Agglomerative Clustering
     For more details on this class, see [sklearn.cluster.AgglomerativeClustering]
@@ -199,7 +211,9 @@ class AgglomerativeClustering(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -283,11 +297,6 @@ class AgglomerativeClustering(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -315,7 +324,9 @@ class AgglomerativeClustering(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -584,6 +595,22 @@ class AgglomerativeClustering(BaseTransformer):
         # each row containing a list of values.
         expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -599,8 +626,8 @@ class AgglomerativeClustering(BaseTransformer):

         return output_df

-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Fit and return the result of each sample's clustering assignment
         For more details on this function, see [sklearn.cluster.AgglomerativeClustering.fit_predict]
         (https://scikit-learn.org/stable/modules/generated/sklearn.cluster.AgglomerativeClustering.html#sklearn.cluster.AgglomerativeClustering.fit_predict)
@@ -615,13 +642,21 @@ class AgglomerativeClustering(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
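In every estimator in this family, the new `fit_predict` and `fit_transform` are gated with scikit-learn's `available_if`, so the method only appears on instances whose wrapped sklearn object actually provides it (and `fit_transform` is hard-disabled here via the `False and ...` check). A minimal standalone sketch of the same gating pattern; `Wrapper` is a made-up class for illustration, not the generated `BaseTransformer`:

from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.utils.metaestimators import available_if


class Wrapper:
    """Illustrative stand-in for the generated BaseTransformer subclasses."""

    def __init__(self, inner):
        self._sklearn_object = inner

    def _fit_predict_enabled(self):
        # Same shape as the generated check: expose fit_predict only when
        # the wrapped estimator has a callable fit_predict.
        return callable(getattr(self._sklearn_object, "fit_predict", None))

    @available_if(_fit_predict_enabled)
    def fit_predict(self, X):
        return self._sklearn_object.fit_predict(X)


hasattr(Wrapper(KMeans(n_clusters=2)), "fit_predict")  # True
hasattr(Wrapper(PCA()), "fit_predict")                 # False: gated off

When the check returns falsy, `available_if` raises `AttributeError` on access, which is why `hasattr` (and duck-typing callers) see the method disappear rather than fail at call time.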
snowflake/ml/modeling/cluster/birch.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return True and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class Birch(BaseTransformer):
     r"""Implements the BIRCH clustering algorithm
     For more details on this class, see [sklearn.cluster.Birch]
@@ -161,7 +173,9 @@ class Birch(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -241,11 +255,6 @@ class Birch(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -273,7 +282,9 @@ class Birch(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -546,6 +557,22 @@ class Birch(BaseTransformer):
         # each row containing a list of values.
         expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -561,8 +588,8 @@ class Birch(BaseTransformer):

         return output_df

-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Perform clustering on `X` and returns cluster labels
         For more details on this function, see [sklearn.cluster.Birch.fit_predict]
         (https://scikit-learn.org/stable/modules/generated/sklearn.cluster.Birch.html#sklearn.cluster.Birch.fit_predict)
@@ -577,13 +604,21 @@ class Birch(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
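The new `expected_dtype` fallback repeated in each `transform` reduces to one rule: keep a concrete Snowflake type only when every inferred output type agrees and the output column count matches, otherwise emit `ARRAY`. A standalone sketch of that decision, substituting plain type-name strings for the internal `_infer_signature`/`convert_sp_to_sf_type` helpers (which are assumptions here, not reproduced):

from typing import List


def expected_transform_dtype(output_types: List[str], n_output_cols: int) -> str:
    """Illustrative mirror of the two conditions in the generated code."""
    # 1) All output types must agree, else fall back to a variant-style ARRAY.
    # 2) The type count must equal the output column count, else the transform
    #    likely packs multiple values into an ARRAY per row.
    if output_types and all(t == output_types[0] for t in output_types) and len(output_types) == n_output_cols:
        return output_types[0]
    return "ARRAY"


expected_transform_dtype(["FLOAT", "FLOAT"], 2)   # 'FLOAT'
expected_transform_dtype(["FLOAT", "NUMBER"], 2)  # 'ARRAY'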
snowflake/ml/modeling/cluster/bisecting_k_means.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return True and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class BisectingKMeans(BaseTransformer):
     r"""Bisecting K-Means clustering
     For more details on this class, see [sklearn.cluster.BisectingKMeans]
@@ -205,7 +217,9 @@ class BisectingKMeans(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -290,11 +304,6 @@ class BisectingKMeans(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -322,7 +331,9 @@ class BisectingKMeans(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -595,6 +606,22 @@ class BisectingKMeans(BaseTransformer):
         # each row containing a list of values.
         expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -610,8 +637,8 @@ class BisectingKMeans(BaseTransformer):

         return output_df

-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Compute cluster centers and predict cluster index for each sample
         For more details on this function, see [sklearn.cluster.BisectingKMeans.fit_predict]
         (https://scikit-learn.org/stable/modules/generated/sklearn.cluster.BisectingKMeans.html#sklearn.cluster.BisectingKMeans.fit_predict)
@@ -626,13 +653,21 @@ class BisectingKMeans(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
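The new `fit_predict` body is identical for every clusterer: call `fit`, then return the fitted object's `labels_`. That leans on the scikit-learn convention that, for clusterers, fitting and reading `labels_` gives the same assignments as `fit_predict` (given the same random state), which a plain sklearn check illustrates:

import numpy as np
from sklearn.cluster import KMeans

X = np.random.RandomState(0).rand(50, 2)

# What the generated fit_predict returns: fit, then read labels_.
labels_attr = KMeans(n_clusters=3, n_init=10, random_state=0).fit(X).labels_
# What sklearn's own fit_predict returns.
labels_call = KMeans(n_clusters=3, n_init=10, random_state=0).fit_predict(X)
assert np.array_equal(labels_attr, labels_call)  # same assignments either way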
snowflake/ml/modeling/cluster/dbscan.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return True and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class DBSCAN(BaseTransformer):
     r"""Perform DBSCAN clustering from vector array or distance matrix
     For more details on this class, see [sklearn.cluster.DBSCAN]
@@ -175,7 +187,9 @@ class DBSCAN(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -258,11 +272,6 @@ class DBSCAN(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -290,7 +299,9 @@ class DBSCAN(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -559,6 +570,22 @@ class DBSCAN(BaseTransformer):
         # each row containing a list of values.
         expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -574,8 +601,8 @@ class DBSCAN(BaseTransformer):

         return output_df

-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Compute clusters from a data or distance matrix and predict labels
         For more details on this function, see [sklearn.cluster.DBSCAN.fit_predict]
         (https://scikit-learn.org/stable/modules/generated/sklearn.cluster.DBSCAN.html#sklearn.cluster.DBSCAN.fit_predict)
@@ -590,13 +617,21 @@ class DBSCAN(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
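Every constructor now records `_use_external_memory_version` and `_batch_size` and forwards them to the trainer builder. Judging by the file list, only the XGBoost estimators (+69 each) actually enable the path, backed by the new `xgboost_external_memory_trainer.py` (+444); the clusterers here hard-code it off. The general mechanism in XGBoost is an iterator-backed DMatrix; a minimal sketch of that pattern, independent of Snowflake's internal trainer (the in-memory batch list is a stand-in for reading partitions from storage):

import numpy as np
import xgboost


class BatchIter(xgboost.DataIter):
    """Feed training data to XGBoost one batch at a time."""

    def __init__(self, batches):
        self._batches = batches  # list of (X, y) pairs, e.g. one per partition
        self._idx = 0
        super().__init__(cache_prefix="xgb_cache")  # on-disk cache enables external memory

    def next(self, input_data):
        if self._idx == len(self._batches):
            return 0  # no more batches
        X, y = self._batches[self._idx]
        input_data(data=X, label=y)
        self._idx += 1
        return 1

    def reset(self):
        self._idx = 0


rng = np.random.RandomState(0)
batches = [(rng.rand(100, 4), rng.randint(0, 2, 100)) for _ in range(3)]
dtrain = xgboost.DMatrix(BatchIter(batches))
booster = xgboost.train({"objective": "binary:logistic", "tree_method": "hist"}, dtrain)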
snowflake/ml/modeling/cluster/feature_agglomeration.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return True and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class FeatureAgglomeration(BaseTransformer):
     r"""Agglomerate features
     For more details on this class, see [sklearn.cluster.FeatureAgglomeration]
@@ -205,7 +217,9 @@ class FeatureAgglomeration(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -290,11 +304,6 @@ class FeatureAgglomeration(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -322,7 +331,9 @@ class FeatureAgglomeration(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -593,6 +604,22 @@ class FeatureAgglomeration(BaseTransformer):
         # each row containing a list of values.
         expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -608,8 +635,8 @@ class FeatureAgglomeration(BaseTransformer):

         return output_df

-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Fit and return the result of each sample's clustering assignment
         For more details on this function, see [sklearn.cluster.FeatureAgglomeration.fit_predict]
         (https://scikit-learn.org/stable/modules/generated/sklearn.cluster.FeatureAgglomeration.html#sklearn.cluster.FeatureAgglomeration.fit_predict)
@@ -624,13 +651,21 @@ class FeatureAgglomeration(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.