snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
- snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
- snowflake/ml/_internal/env_utils.py +31 -52
- snowflake/ml/_internal/file_utils.py +17 -0
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/_internal/utils/query_result_checker.py +8 -5
- snowflake/ml/_internal/utils/snowflake_env.py +95 -0
- snowflake/ml/fileset/parquet_parser.py +31 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/model_impl.py +172 -13
- snowflake/ml/model/_client/model/model_version_impl.py +96 -52
- snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
- snowflake/ml/model/_client/ops/model_ops.py +155 -9
- snowflake/ml/model/_client/sql/model.py +55 -10
- snowflake/ml/model/_client/sql/model_version.py +72 -61
- snowflake/ml/model/_client/sql/stage.py +10 -4
- snowflake/ml/model/_client/sql/tag.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/_signatures/core.py +20 -17
- snowflake/ml/model/custom_model.py +30 -27
- snowflake/ml/model/model_signature.py +16 -17
- snowflake/ml/model/type_hints.py +3 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
- snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
- snowflake/ml/modeling/_internal/model_specifications.py +3 -10
- snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
- snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
- snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
- snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
- snowflake/ml/modeling/cluster/birch.py +51 -16
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
- snowflake/ml/modeling/cluster/dbscan.py +51 -16
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
- snowflake/ml/modeling/cluster/k_means.py +51 -16
- snowflake/ml/modeling/cluster/mean_shift.py +51 -16
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
- snowflake/ml/modeling/cluster/optics.py +51 -16
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
- snowflake/ml/modeling/compose/column_transformer.py +51 -16
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
- snowflake/ml/modeling/covariance/oas.py +51 -16
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/pca.py +51 -16
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
- snowflake/ml/modeling/impute/knn_imputer.py +51 -16
- snowflake/ml/modeling/impute/missing_indicator.py +51 -16
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/lars.py +51 -16
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/perceptron.py +51 -16
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ridge.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
- snowflake/ml/modeling/manifold/isomap.py +51 -16
- snowflake/ml/modeling/manifold/mds.py +51 -16
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
- snowflake/ml/modeling/manifold/tsne.py +51 -16
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
- snowflake/ml/modeling/svm/linear_svc.py +51 -16
- snowflake/ml/modeling/svm/linear_svr.py +51 -16
- snowflake/ml/modeling/svm/nu_svc.py +51 -16
- snowflake/ml/modeling/svm/nu_svr.py +51 -16
- snowflake/ml/modeling/svm/svc.py +51 -16
- snowflake/ml/modeling/svm/svr.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
- snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
- snowflake/ml/registry/__init__.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +163 -0
- snowflake/ml/registry/model_registry.py +12 -0
- snowflake/ml/registry/registry.py +100 -90
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
- snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
- snowflake/ml/model/_client/model/model_method_info.py +0 -19
- snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
- /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
- /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
snowflake/ml/modeling/feature_selection/select_percentile.py

@@ -55,6 +55,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.feature_selection".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class SelectPercentile(BaseTransformer):
     r"""Select features according to a percentile of the highest scores
     For more details on this class, see [sklearn.feature_selection.SelectPercentile]

@@ -136,7 +148,9 @@ class SelectPercentile(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)

@@ -213,11 +227,6 @@ class SelectPercentile(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:

@@ -245,7 +254,9 @@ class SelectPercentile(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True

@@ -516,6 +527,22 @@ class SelectPercentile(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",

@@ -531,8 +558,8 @@ class SelectPercentile(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 

@@ -545,13 +572,21 @@ class SelectPercentile(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
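The `_is_fit_predict_method_enabled` / `_is_fit_transform_method_enabled` helpers added above feed scikit-learn's `available_if` decorator, which calls the returned check against each instance and hides the method when the check fails. Because the generated body is `return False and callable(...)`, the check short-circuits to `False` every time, so `fit_predict` and `fit_transform` remain unavailable on these classes even though the wrappers now exist. A minimal sketch of that gating, assuming only scikit-learn; `Demo` and its `_sklearn_object` attribute are hypothetical stand-ins, not snowflake-ml-python code:

# Minimal sketch of the `available_if` gating used by the generated wrappers.
# `Demo` and `_sklearn_object` are hypothetical stand-ins.
from sklearn.utils.metaestimators import available_if


def _always_disabled():
    def check(self) -> bool:
        # Mirrors the generated body: `False and ...` short-circuits to False
        # before ever inspecting `self._sklearn_object`.
        return False and callable(getattr(self._sklearn_object, "fit_predict", None))

    return check


class Demo:
    _sklearn_object = None  # stand-in for the wrapped sklearn estimator

    @available_if(_always_disabled())
    def fit_predict(self, dataset):
        return dataset


print(hasattr(Demo(), "fit_predict"))  # False: accessing the method raises AttributeError

This lines up with the generated method bodies, which return `labels_` or `embedding_` — attributes only clustering and manifold estimators carry — so the always-False gate keeps the autogenerated methods from surfacing on feature-selection and Gaussian-process classes where they would fail.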
snowflake/ml/modeling/feature_selection/sequential_feature_selector.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.feature_selection".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class SequentialFeatureSelector(BaseTransformer):
     r"""Transformer that performs Sequential Feature Selection
     For more details on this class, see [sklearn.feature_selection.SequentialFeatureSelector]

@@ -189,7 +201,9 @@ class SequentialFeatureSelector(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
         deps = deps | gather_dependencies(estimator)
         self._deps = list(deps)

@@ -271,11 +285,6 @@ class SequentialFeatureSelector(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:

@@ -303,7 +312,9 @@ class SequentialFeatureSelector(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True

@@ -574,6 +585,22 @@ class SequentialFeatureSelector(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",

@@ -589,8 +616,8 @@ class SequentialFeatureSelector(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 

@@ -603,13 +630,21 @@ class SequentialFeatureSelector(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
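The constructor hunks above also initialize `self._use_external_memory_version = False` and `self._batch_size = -1` and thread both into the model-trainer call, matching the new `snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py` in the file list. For orientation, here is a hedged sketch of the plain-XGBoost external-memory pattern such a trainer builds on, feeding data batch by batch through a `DataIter` instead of materializing one large matrix in memory; this is generic `xgboost` usage, not snowflake-ml-python code:

# Hedged sketch of XGBoost external-memory training (plain xgboost API,
# not snowflake-ml-python code): batches stream through a DataIter.
import numpy as np
import xgboost as xgb


class BatchIter(xgb.DataIter):
    def __init__(self, batches):
        self._batches = batches  # list of (X, y) batches
        self._it = 0
        super().__init__(cache_prefix="cache")  # on-disk cache for DMatrix pages

    def next(self, input_data):
        if self._it == len(self._batches):
            return 0  # no more batches
        X, y = self._batches[self._it]
        input_data(data=X, label=y)  # hand the current batch to XGBoost
        self._it += 1
        return 1

    def reset(self):
        self._it = 0


rng = np.random.default_rng(0)
batches = [(rng.normal(size=(256, 4)), rng.integers(0, 2, size=256)) for _ in range(4)]
dtrain = xgb.DMatrix(BatchIter(batches))
booster = xgb.train({"tree_method": "hist", "objective": "binary:logistic"}, dtrain)

With the defaults shipped here (`False` / `-1`), the external-memory path stays off and training behaves as before.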
snowflake/ml/modeling/feature_selection/variance_threshold.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.feature_selection".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class VarianceThreshold(BaseTransformer):
     r"""Feature selector that removes all low-variance features
     For more details on this class, see [sklearn.feature_selection.VarianceThreshold]

@@ -128,7 +140,9 @@ class VarianceThreshold(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)

@@ -204,11 +218,6 @@ class VarianceThreshold(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:

@@ -236,7 +245,9 @@ class VarianceThreshold(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True

@@ -507,6 +518,22 @@ class VarianceThreshold(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",

@@ -522,8 +549,8 @@ class VarianceThreshold(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 

@@ -536,13 +563,21 @@ class VarianceThreshold(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.gaussian_process".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class GaussianProcessClassifier(BaseTransformer):
     r"""Gaussian process classification (GPC) based on Laplace approximation
     For more details on this class, see [sklearn.gaussian_process.GaussianProcessClassifier]

@@ -215,7 +227,9 @@ class GaussianProcessClassifier(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)

@@ -299,11 +313,6 @@ class GaussianProcessClassifier(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:

@@ -331,7 +340,9 @@ class GaussianProcessClassifier(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True

@@ -602,6 +613,22 @@ class GaussianProcessClassifier(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",

@@ -617,8 +644,8 @@ class GaussianProcessClassifier(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 

@@ -631,13 +658,21 @@ class GaussianProcessClassifier(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.gaussian_process".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class GaussianProcessRegressor(BaseTransformer):
     r"""Gaussian process regression (GPR)
     For more details on this class, see [sklearn.gaussian_process.GaussianProcessRegressor]

@@ -207,7 +219,9 @@ class GaussianProcessRegressor(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)

@@ -290,11 +304,6 @@ class GaussianProcessRegressor(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:

@@ -322,7 +331,9 @@ class GaussianProcessRegressor(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True

@@ -593,6 +604,22 @@ class GaussianProcessRegressor(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",

@@ -608,8 +635,8 @@ class GaussianProcessRegressor(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 

@@ -622,13 +649,21 @@ class GaussianProcessRegressor(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.