snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff shows the contents of two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
- snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
- snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
- snowflake/ml/_internal/env_utils.py +31 -52
- snowflake/ml/_internal/file_utils.py +17 -0
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/_internal/utils/query_result_checker.py +8 -5
- snowflake/ml/_internal/utils/snowflake_env.py +95 -0
- snowflake/ml/fileset/parquet_parser.py +31 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/model_impl.py +172 -13
- snowflake/ml/model/_client/model/model_version_impl.py +96 -52
- snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
- snowflake/ml/model/_client/ops/model_ops.py +155 -9
- snowflake/ml/model/_client/sql/model.py +55 -10
- snowflake/ml/model/_client/sql/model_version.py +72 -61
- snowflake/ml/model/_client/sql/stage.py +10 -4
- snowflake/ml/model/_client/sql/tag.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/_signatures/core.py +20 -17
- snowflake/ml/model/custom_model.py +30 -27
- snowflake/ml/model/model_signature.py +16 -17
- snowflake/ml/model/type_hints.py +3 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
- snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
- snowflake/ml/modeling/_internal/model_specifications.py +3 -10
- snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
- snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
- snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
- snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
- snowflake/ml/modeling/cluster/birch.py +51 -16
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
- snowflake/ml/modeling/cluster/dbscan.py +51 -16
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
- snowflake/ml/modeling/cluster/k_means.py +51 -16
- snowflake/ml/modeling/cluster/mean_shift.py +51 -16
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
- snowflake/ml/modeling/cluster/optics.py +51 -16
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
- snowflake/ml/modeling/compose/column_transformer.py +51 -16
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
- snowflake/ml/modeling/covariance/oas.py +51 -16
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/pca.py +51 -16
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
- snowflake/ml/modeling/impute/knn_imputer.py +51 -16
- snowflake/ml/modeling/impute/missing_indicator.py +51 -16
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/lars.py +51 -16
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/perceptron.py +51 -16
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ridge.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
- snowflake/ml/modeling/manifold/isomap.py +51 -16
- snowflake/ml/modeling/manifold/mds.py +51 -16
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
- snowflake/ml/modeling/manifold/tsne.py +51 -16
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
- snowflake/ml/modeling/svm/linear_svc.py +51 -16
- snowflake/ml/modeling/svm/linear_svr.py +51 -16
- snowflake/ml/modeling/svm/nu_svc.py +51 -16
- snowflake/ml/modeling/svm/nu_svr.py +51 -16
- snowflake/ml/modeling/svm/svc.py +51 -16
- snowflake/ml/modeling/svm/svr.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
- snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
- snowflake/ml/registry/__init__.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +163 -0
- snowflake/ml/registry/model_registry.py +12 -0
- snowflake/ml/registry/registry.py +100 -90
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
- snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
- snowflake/ml/model/_client/model/model_method_info.py +0 -19
- snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
- /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
- /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
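Highlights of this release, judging from the file list: a reworked model registry entry point (snowflake/ml/registry/registry.py plus the new snowflake/ml/registry/_manager/model_manager.py), tag support in the model client (snowflake/ml/model/_client/sql/tag.py), an external-memory XGBoost trainer (snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py), and a uniform +51/-16 change to every autogenerated modeling wrapper, shown below for the feature_selection modules. As a hedged sketch only (assuming the 1.2-era snowflake.ml.registry.Registry API; the connection parameters, model, and DataFrames are placeholders, not taken from this diff), the reworked registry is used roughly like this:

    # Hedged sketch of the reworked registry entry point suggested by the
    # changes to registry.py and the new _manager/model_manager.py.
    import pandas as pd
    from sklearn.linear_model import LinearRegression
    from snowflake.ml.registry import Registry
    from snowflake.snowpark import Session

    connection_parameters = {"account": "...", "user": "...", "password": "..."}  # placeholders
    session = Session.builder.configs(connection_parameters).create()

    train_df = pd.DataFrame({"X": [1.0, 2.0, 3.0], "Y": [2.0, 4.0, 6.0]})
    my_model = LinearRegression().fit(train_df[["X"]], train_df["Y"])

    registry = Registry(session=session, database_name="ML", schema_name="REGISTRY")

    # log_model returns a ModelVersion handle; sample_input_data lets the
    # registry infer the model's signature.
    mv = registry.log_model(
        my_model,
        model_name="MY_MODEL",
        version_name="V1",
        sample_input_data=train_df[["X"]],
    )
    predictions = mv.run(train_df[["X"]], function_name="predict")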
@@ -55,6 +55,18 @@ _PROJECT = "ModelDevelopment"
|
|
55
55
|
_SUBPROJECT = "".join([s.capitalize() for s in "sklearn.feature_selection".replace("sklearn.", "").split("_")])
|
56
56
|
|
57
57
|
|
58
|
+
def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
|
59
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
60
|
+
return False and callable(getattr(self._sklearn_object, "fit_predict", None))
|
61
|
+
return check
|
62
|
+
|
63
|
+
|
64
|
+
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
65
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
66
|
+
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
67
|
+
return check
|
68
|
+
|
69
|
+
|
58
70
|
class GenericUnivariateSelect(BaseTransformer):
|
59
71
|
r"""Univariate feature selector with configurable strategy
|
60
72
|
For more details on this class, see [sklearn.feature_selection.GenericUnivariateSelect]
|
@@ -139,7 +151,9 @@ class GenericUnivariateSelect(BaseTransformer):
|
|
139
151
|
self.set_label_cols(label_cols)
|
140
152
|
self.set_passthrough_cols(passthrough_cols)
|
141
153
|
self.set_drop_input_cols(drop_input_cols)
|
142
|
-
self.set_sample_weight_col(sample_weight_col)
|
154
|
+
self.set_sample_weight_col(sample_weight_col)
|
155
|
+
self._use_external_memory_version = False
|
156
|
+
self._batch_size = -1
|
143
157
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
144
158
|
|
145
159
|
self._deps = list(deps)
|
@@ -217,11 +231,6 @@ class GenericUnivariateSelect(BaseTransformer):
|
|
217
231
|
if isinstance(dataset, DataFrame):
|
218
232
|
session = dataset._session
|
219
233
|
assert session is not None # keep mypy happy
|
220
|
-
# Validate that key package version in user workspace are supported in snowflake conda channel
|
221
|
-
# If customer doesn't have package in conda channel, replace the ones have the closest versions
|
222
|
-
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
223
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
224
|
-
|
225
234
|
# Specify input columns so column pruning will be enforced
|
226
235
|
selected_cols = self._get_active_columns()
|
227
236
|
if len(selected_cols) > 0:
|
@@ -249,7 +258,9 @@ class GenericUnivariateSelect(BaseTransformer):
|
|
249
258
|
label_cols=self.label_cols,
|
250
259
|
sample_weight_col=self.sample_weight_col,
|
251
260
|
autogenerated=self._autogenerated,
|
252
|
-
subproject=_SUBPROJECT
|
261
|
+
subproject=_SUBPROJECT,
|
262
|
+
use_external_memory_version=self._use_external_memory_version,
|
263
|
+
batch_size=self._batch_size,
|
253
264
|
)
|
254
265
|
self._sklearn_object = model_trainer.train()
|
255
266
|
self._is_fitted = True
|
@@ -520,6 +531,22 @@ class GenericUnivariateSelect(BaseTransformer):
|
|
520
531
|
# each row containing a list of values.
|
521
532
|
expected_dtype = "ARRAY"
|
522
533
|
|
534
|
+
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
535
|
+
if expected_dtype == "":
|
536
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
|
537
|
+
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
538
|
+
expected_dtype = "ARRAY"
|
539
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
|
540
|
+
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
541
|
+
expected_dtype = "ARRAY"
|
542
|
+
else:
|
543
|
+
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
544
|
+
# We can only infer the output types from the input types if the following two statemetns are true:
|
545
|
+
# 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
|
546
|
+
# 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
|
547
|
+
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
548
|
+
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
549
|
+
|
523
550
|
output_df = self._batch_inference(
|
524
551
|
dataset=dataset,
|
525
552
|
inference_method="transform",
|
@@ -535,8 +562,8 @@ class GenericUnivariateSelect(BaseTransformer):
|
|
535
562
|
|
536
563
|
return output_df
|
537
564
|
|
538
|
-
@available_if(
|
539
|
-
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
|
565
|
+
@available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
|
566
|
+
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
540
567
|
""" Method not supported for this class.
|
541
568
|
|
542
569
|
|
@@ -549,13 +576,21 @@ class GenericUnivariateSelect(BaseTransformer):
|
|
549
576
|
Returns:
|
550
577
|
Predicted dataset.
|
551
578
|
"""
|
552
|
-
|
553
|
-
|
554
|
-
|
555
|
-
|
556
|
-
|
557
|
-
|
558
|
-
|
579
|
+
self.fit(dataset)
|
580
|
+
assert self._sklearn_object is not None
|
581
|
+
return self._sklearn_object.labels_
|
582
|
+
|
583
|
+
|
584
|
+
@available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
|
585
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
586
|
+
"""
|
587
|
+
Returns:
|
588
|
+
Transformed dataset.
|
589
|
+
"""
|
590
|
+
self.fit(dataset)
|
591
|
+
assert self._sklearn_object is not None
|
592
|
+
return self._sklearn_object.embedding_
|
593
|
+
|
559
594
|
|
560
595
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
561
596
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
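The wrapper diff above shows the recurring pattern across these autogenerated modules: fit_predict and fit_transform are now gated with scikit-learn's available_if and generated check functions. A minimal standalone sketch of that mechanism (assuming scikit-learn >= 1.0; the Demo class is illustrative, not library code) — note that the generated check's `False and ...` always evaluates to False, so the method stays hidden on these selectors:

    # Minimal sketch of the available_if gating used above.
    # available_if hides a method unless its check returns True.
    from sklearn.utils.metaestimators import available_if

    def _is_fit_predict_method_enabled():
        def check(self) -> bool:
            # As generated above: `False and ...` always evaluates to False,
            # so fit_predict is never exposed on this wrapper.
            return False and callable(getattr(self._sklearn_object, "fit_predict", None))
        return check

    class Demo:
        _sklearn_object = None

        @available_if(_is_fit_predict_method_enabled())
        def fit_predict(self, dataset):
            return dataset

    print(hasattr(Demo(), "fit_predict"))  # False: the check gates attribute access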
snowflake/ml/modeling/feature_selection/select_fdr.py

@@ -55,6 +55,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.feature_selection".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class SelectFdr(BaseTransformer):
     r"""Filter: Select the p-values for an estimated false discovery rate
     For more details on this class, see [sklearn.feature_selection.SelectFdr]
@@ -136,7 +148,9 @@ class SelectFdr(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -213,11 +227,6 @@ class SelectFdr(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -245,7 +254,9 @@ class SelectFdr(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -516,6 +527,22 @@ class SelectFdr(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -531,8 +558,8 @@ class SelectFdr(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -545,13 +572,21 @@ class SelectFdr(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
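The expected_dtype fallback added to transform() (identical in every wrapper in this diff) boils down to a small decision rule: keep ARRAY unless every inferred output type matches and the column counts line up. An illustrative restatement for clarity (the function and parameter names are placeholders, not library API; in the real code the types come from _infer_signature(...).as_snowpark_type() and convert_sp_to_sf_type):

    # Illustrative restatement of the dtype-inference fallback added above.
    from typing import Callable, List, Sequence

    def infer_expected_dtype(
        output_types: Sequence[object],
        output_cols: List[str],
        convert_sp_to_sf_type: Callable[[object], str],
    ) -> str:
        # A single concrete type is only safe when every inferred output type
        # matches and there is exactly one output column per inferred type.
        if output_types and all(t == output_types[0] for t in output_types) \
                and len(output_types) == len(output_cols):
            return convert_sp_to_sf_type(output_types[0])
        # Otherwise fall back to ARRAY (variant-style) output.
        return "ARRAY"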
snowflake/ml/modeling/feature_selection/select_fpr.py

@@ -55,6 +55,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.feature_selection".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class SelectFpr(BaseTransformer):
     r"""Filter: Select the pvalues below alpha based on a FPR test
     For more details on this class, see [sklearn.feature_selection.SelectFpr]
@@ -136,7 +148,9 @@ class SelectFpr(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -213,11 +227,6 @@ class SelectFpr(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -245,7 +254,9 @@ class SelectFpr(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -516,6 +527,22 @@ class SelectFpr(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -531,8 +558,8 @@ class SelectFpr(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -545,13 +572,21 @@ class SelectFpr(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
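The new fit_predict and fit_transform bodies above follow the usual scikit-learn convention: fit the wrapped estimator, then read the fitted attribute it exposes (labels_ for clusterers, embedding_ for manifold learners). A plain scikit-learn illustration of that convention, unrelated to Snowflake:

    # Plain scikit-learn illustration of the delegation pattern used above:
    # fit, then read the attribute the estimator exposes after fitting.
    import numpy as np
    from sklearn.cluster import KMeans

    X = np.array([[0.0], [0.1], [5.0], [5.1]])
    km = KMeans(n_clusters=2, n_init=10, random_state=0).fit(X)
    labels = km.labels_  # what a generated fit_predict would return
    print(labels)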
snowflake/ml/modeling/feature_selection/select_fwe.py

@@ -55,6 +55,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.feature_selection".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class SelectFwe(BaseTransformer):
     r"""Filter: Select the p-values corresponding to Family-wise error rate
     For more details on this class, see [sklearn.feature_selection.SelectFwe]
@@ -136,7 +148,9 @@ class SelectFwe(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -213,11 +227,6 @@ class SelectFwe(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -245,7 +254,9 @@ class SelectFwe(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -516,6 +527,22 @@ class SelectFwe(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -531,8 +558,8 @@ class SelectFwe(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -545,13 +572,21 @@ class SelectFwe(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
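Every wrapper __init__ now sets _use_external_memory_version = False and _batch_size = -1 and forwards both to the trainer builder, matching the new xgboost_external_memory_trainer.py in the file list. A hypothetical sketch of the batch_size semantics this implies (the helper and its sentinel handling are assumptions, not the library's API):

    # Hypothetical sketch of the batch_size semantics suggested by the diff:
    # batch_size=-1 acts as an "unset" sentinel; otherwise rows are streamed
    # in fixed-size chunks instead of materializing the whole training set.
    from typing import Iterator, Sequence

    def iter_batches(rows: Sequence[object], batch_size: int = -1) -> Iterator[Sequence[object]]:
        if batch_size <= 0:  # -1 sentinel: no batching, single pass
            yield rows
            return
        for start in range(0, len(rows), batch_size):
            yield rows[start:start + batch_size]

    # Example: two chunks of 2 rows each, then a trailing chunk of 1.
    for chunk in iter_batches(list(range(5)), batch_size=2):
        print(chunk)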
snowflake/ml/modeling/feature_selection/select_k_best.py

@@ -55,6 +55,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.feature_selection".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class SelectKBest(BaseTransformer):
     r"""Select features according to the k highest scores
     For more details on this class, see [sklearn.feature_selection.SelectKBest]
@@ -137,7 +149,9 @@ class SelectKBest(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -214,11 +228,6 @@ class SelectKBest(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -246,7 +255,9 @@ class SelectKBest(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -517,6 +528,22 @@ class SelectKBest(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -532,8 +559,8 @@ class SelectKBest(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -546,13 +573,21 @@ class SelectKBest(BaseTransformer):
         Returns:
            Predicted dataset.
        """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.