snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
- snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
- snowflake/ml/_internal/env_utils.py +31 -52
- snowflake/ml/_internal/file_utils.py +17 -0
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/_internal/utils/query_result_checker.py +8 -5
- snowflake/ml/_internal/utils/snowflake_env.py +95 -0
- snowflake/ml/fileset/parquet_parser.py +31 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/model_impl.py +172 -13
- snowflake/ml/model/_client/model/model_version_impl.py +96 -52
- snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
- snowflake/ml/model/_client/ops/model_ops.py +155 -9
- snowflake/ml/model/_client/sql/model.py +55 -10
- snowflake/ml/model/_client/sql/model_version.py +72 -61
- snowflake/ml/model/_client/sql/stage.py +10 -4
- snowflake/ml/model/_client/sql/tag.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/_signatures/core.py +20 -17
- snowflake/ml/model/custom_model.py +30 -27
- snowflake/ml/model/model_signature.py +16 -17
- snowflake/ml/model/type_hints.py +3 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
- snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
- snowflake/ml/modeling/_internal/model_specifications.py +3 -10
- snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
- snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
- snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
- snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
- snowflake/ml/modeling/cluster/birch.py +51 -16
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
- snowflake/ml/modeling/cluster/dbscan.py +51 -16
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
- snowflake/ml/modeling/cluster/k_means.py +51 -16
- snowflake/ml/modeling/cluster/mean_shift.py +51 -16
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
- snowflake/ml/modeling/cluster/optics.py +51 -16
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
- snowflake/ml/modeling/compose/column_transformer.py +51 -16
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
- snowflake/ml/modeling/covariance/oas.py +51 -16
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/pca.py +51 -16
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
- snowflake/ml/modeling/impute/knn_imputer.py +51 -16
- snowflake/ml/modeling/impute/missing_indicator.py +51 -16
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/lars.py +51 -16
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/perceptron.py +51 -16
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ridge.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
- snowflake/ml/modeling/manifold/isomap.py +51 -16
- snowflake/ml/modeling/manifold/mds.py +51 -16
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
- snowflake/ml/modeling/manifold/tsne.py +51 -16
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
- snowflake/ml/modeling/svm/linear_svc.py +51 -16
- snowflake/ml/modeling/svm/linear_svr.py +51 -16
- snowflake/ml/modeling/svm/nu_svc.py +51 -16
- snowflake/ml/modeling/svm/nu_svr.py +51 -16
- snowflake/ml/modeling/svm/svc.py +51 -16
- snowflake/ml/modeling/svm/svr.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
- snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
- snowflake/ml/registry/__init__.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +163 -0
- snowflake/ml/registry/model_registry.py +12 -0
- snowflake/ml/registry/registry.py +100 -90
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
- snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
- snowflake/ml/model/_client/model/model_method_info.py +0 -19
- snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
- /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
- /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
--- a/snowflake/ml/modeling/neighbors/k_neighbors_classifier.py
+++ b/snowflake/ml/modeling/neighbors/k_neighbors_classifier.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neighbors".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class KNeighborsClassifier(BaseTransformer):
     r"""Classifier implementing the k-nearest neighbors vote
     For more details on this class, see [sklearn.neighbors.KNeighborsClassifier]
@@ -198,7 +210,9 @@ class KNeighborsClassifier(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -281,11 +295,6 @@ class KNeighborsClassifier(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -313,7 +322,9 @@ class KNeighborsClassifier(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -584,6 +595,22 @@ class KNeighborsClassifier(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -599,8 +626,8 @@ class KNeighborsClassifier(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -613,13 +640,21 @@ class KNeighborsClassifier(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
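The helper pair added at the top of each regenerated file plugs into scikit-learn's `available_if` decorator: `available_if(check)` invokes `check(self)` on attribute access and raises `AttributeError` when the check is falsy, so `hasattr(estimator, "fit_predict")` comes back `False`. Because the generated check bodies begin with `False and`, they short-circuit to `False` and the method stays hidden regardless of what the wrapped sklearn object supports. A minimal sketch of the pattern; `Gated` and `_gate` are illustrative names, only `available_if` is scikit-learn's:

```python
from sklearn.utils.metaestimators import available_if


def _gate(method_name: str):
    """Return a check shaped like the generated _is_fit_predict_method_enabled."""
    def check(self) -> bool:
        # `False and ...` short-circuits: the right-hand side never runs,
        # so the gated method is always reported as unavailable.
        return False and callable(getattr(self._backing, method_name, None))
    return check


class Gated:
    def __init__(self, backing):
        self._backing = backing

    @available_if(_gate("fit_predict"))
    def fit_predict(self, X):
        return self._backing.fit_predict(X)


g = Gated(backing=None)
print(hasattr(g, "fit_predict"))  # False: available_if hides the method
```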
--- a/snowflake/ml/modeling/neighbors/k_neighbors_regressor.py
+++ b/snowflake/ml/modeling/neighbors/k_neighbors_regressor.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neighbors".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class KNeighborsRegressor(BaseTransformer):
     r"""Regression based on k-nearest neighbors
     For more details on this class, see [sklearn.neighbors.KNeighborsRegressor]
@@ -200,7 +212,9 @@ class KNeighborsRegressor(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -283,11 +297,6 @@ class KNeighborsRegressor(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -315,7 +324,9 @@ class KNeighborsRegressor(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -586,6 +597,22 @@ class KNeighborsRegressor(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -601,8 +628,8 @@ class KNeighborsRegressor(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -615,13 +642,21 @@ class KNeighborsRegressor(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
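The 16-line block added to the transform path fills in `expected_dtype` when the handler factory could not assign one: a clustering or decomposition estimator whose `n_clusters`/`n_components` differs from the number of output columns forces `ARRAY`, and a concrete Snowflake type is chosen only when every input column shares one type and the input and output widths match; otherwise the empty string survives and the output stays a variant. A standalone sketch of that decision logic, with plain strings standing in for the Snowpark and Snowflake types (`infer_expected_dtype` is a hypothetical name; the real code goes through `_infer_signature` and `convert_sp_to_sf_type`):

```python
from typing import List, Optional


def infer_expected_dtype(
    input_types: List[str],          # e.g. ["DOUBLE", "DOUBLE"]
    n_output_cols: int,
    n_clusters: Optional[int] = None,
    n_components: Optional[int] = None,
) -> str:
    # Clustering/decomposition output whose width differs from the
    # estimator's own dimensionality must be packed into an ARRAY column.
    if n_clusters is not None and n_clusters != n_output_cols:
        return "ARRAY"
    if n_components is not None and n_components != n_output_cols:
        return "ARRAY"
    # A concrete type is only safe when all inputs share one type and the
    # input width equals the output width.
    if input_types and all(t == input_types[0] for t in input_types) and len(input_types) == n_output_cols:
        return input_types[0]
    return ""  # unknown: the caller falls back to a variant column


assert infer_expected_dtype(["DOUBLE", "DOUBLE"], 2) == "DOUBLE"
assert infer_expected_dtype(["DOUBLE", "VARCHAR"], 2) == ""
assert infer_expected_dtype(["DOUBLE"] * 8, 2, n_clusters=8) == "ARRAY"
```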
--- a/snowflake/ml/modeling/neighbors/kernel_density.py
+++ b/snowflake/ml/modeling/neighbors/kernel_density.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neighbors".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class KernelDensity(BaseTransformer):
     r"""Kernel Density Estimation
     For more details on this class, see [sklearn.neighbors.KernelDensity]
@@ -176,7 +188,9 @@ class KernelDensity(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -260,11 +274,6 @@ class KernelDensity(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -292,7 +301,9 @@ class KernelDensity(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -561,6 +572,22 @@ class KernelDensity(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -576,8 +603,8 @@ class KernelDensity(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -590,13 +617,21 @@ class KernelDensity(BaseTransformer):
         Returns:
             Predicted dataset.
        """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
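Every regenerated estimator now seeds `self._use_external_memory_version = False` and `self._batch_size = -1` in its constructor and forwards both to the trainer builder. Judging from the file list, only the XGBoost wrappers (+69 lines each instead of the usual +51) flip these flags, backed by the new `xgboost_external_memory_trainer.py` (+444). A hedged sketch of how a builder could branch on them; the class and function names below are illustrative, not the actual `snowflake.ml.modeling._internal` API:

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class SnowparkTrainer:
    estimator: Any

    def train(self) -> Any:
        """Materialize the dataset once and fit in memory."""


@dataclass
class ExternalMemoryTrainer:
    estimator: Any
    batch_size: int

    def train(self) -> Any:
        """Stream the dataset in batch_size-row chunks (XGBoost external-memory style)."""


def build_trainer(estimator: Any, use_external_memory_version: bool, batch_size: int):
    # batch_size == -1 is the sentinel default every estimator now passes.
    if use_external_memory_version and batch_size > 0:
        return ExternalMemoryTrainer(estimator, batch_size)
    return SnowparkTrainer(estimator)


print(type(build_trainer(object(), False, -1)).__name__)     # SnowparkTrainer
print(type(build_trainer(object(), True, 10_000)).__name__)  # ExternalMemoryTrainer
```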
--- a/snowflake/ml/modeling/neighbors/local_outlier_factor.py
+++ b/snowflake/ml/modeling/neighbors/local_outlier_factor.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neighbors".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class LocalOutlierFactor(BaseTransformer):
     r"""Unsupervised Outlier Detection using the Local Outlier Factor (LOF)
     For more details on this class, see [sklearn.neighbors.LocalOutlierFactor]
@@ -204,7 +216,9 @@ class LocalOutlierFactor(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -288,11 +302,6 @@ class LocalOutlierFactor(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -320,7 +329,9 @@ class LocalOutlierFactor(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -591,6 +602,22 @@ class LocalOutlierFactor(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -606,8 +633,8 @@ class LocalOutlierFactor(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Fit the model to the training set X and return the labels
         For more details on this function, see [sklearn.neighbors.LocalOutlierFactor.fit_predict]
         (https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.LocalOutlierFactor.html#sklearn.neighbors.LocalOutlierFactor.fit_predict)
@@ -622,13 +649,21 @@ class LocalOutlierFactor(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
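Note the asymmetry in this file: LocalOutlierFactor's `fit_predict` docstring links to a genuinely supported scikit-learn method, yet the generated gate is the same always-`False` check, and the generated body returns `labels_`, an attribute scikit-learn's LOF does not define (it exposes `negative_outlier_factor_` instead). For comparison, the plain scikit-learn behavior on made-up data:

```python
import numpy as np
from sklearn.neighbors import LocalOutlierFactor

X = np.array([[0.0], [0.1], [0.2], [10.0]])  # one obvious outlier
lof = LocalOutlierFactor(n_neighbors=2)
print(lof.fit_predict(X))  # expected: [ 1  1  1 -1], where -1 marks the outlier
```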
--- a/snowflake/ml/modeling/neighbors/nearest_centroid.py
+++ b/snowflake/ml/modeling/neighbors/nearest_centroid.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neighbors".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class NearestCentroid(BaseTransformer):
     r"""Nearest centroid classifier
     For more details on this class, see [sklearn.neighbors.NearestCentroid]
@@ -144,7 +156,9 @@ class NearestCentroid(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -221,11 +235,6 @@ class NearestCentroid(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -253,7 +262,9 @@ class NearestCentroid(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -524,6 +535,22 @@ class NearestCentroid(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -539,8 +566,8 @@ class NearestCentroid(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -553,13 +580,21 @@ class NearestCentroid(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.