snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
- snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
- snowflake/ml/_internal/env_utils.py +31 -52
- snowflake/ml/_internal/file_utils.py +17 -0
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/_internal/utils/query_result_checker.py +8 -5
- snowflake/ml/_internal/utils/snowflake_env.py +95 -0
- snowflake/ml/fileset/parquet_parser.py +31 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/model_impl.py +172 -13
- snowflake/ml/model/_client/model/model_version_impl.py +96 -52
- snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
- snowflake/ml/model/_client/ops/model_ops.py +155 -9
- snowflake/ml/model/_client/sql/model.py +55 -10
- snowflake/ml/model/_client/sql/model_version.py +72 -61
- snowflake/ml/model/_client/sql/stage.py +10 -4
- snowflake/ml/model/_client/sql/tag.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/_signatures/core.py +20 -17
- snowflake/ml/model/custom_model.py +30 -27
- snowflake/ml/model/model_signature.py +16 -17
- snowflake/ml/model/type_hints.py +3 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
- snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
- snowflake/ml/modeling/_internal/model_specifications.py +3 -10
- snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
- snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
- snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
- snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
- snowflake/ml/modeling/cluster/birch.py +51 -16
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
- snowflake/ml/modeling/cluster/dbscan.py +51 -16
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
- snowflake/ml/modeling/cluster/k_means.py +51 -16
- snowflake/ml/modeling/cluster/mean_shift.py +51 -16
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
- snowflake/ml/modeling/cluster/optics.py +51 -16
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
- snowflake/ml/modeling/compose/column_transformer.py +51 -16
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
- snowflake/ml/modeling/covariance/oas.py +51 -16
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/pca.py +51 -16
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
- snowflake/ml/modeling/impute/knn_imputer.py +51 -16
- snowflake/ml/modeling/impute/missing_indicator.py +51 -16
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/lars.py +51 -16
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/perceptron.py +51 -16
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ridge.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
- snowflake/ml/modeling/manifold/isomap.py +51 -16
- snowflake/ml/modeling/manifold/mds.py +51 -16
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
- snowflake/ml/modeling/manifold/tsne.py +51 -16
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
- snowflake/ml/modeling/svm/linear_svc.py +51 -16
- snowflake/ml/modeling/svm/linear_svr.py +51 -16
- snowflake/ml/modeling/svm/nu_svc.py +51 -16
- snowflake/ml/modeling/svm/nu_svr.py +51 -16
- snowflake/ml/modeling/svm/svc.py +51 -16
- snowflake/ml/modeling/svm/svr.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
- snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
- snowflake/ml/registry/__init__.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +163 -0
- snowflake/ml/registry/model_registry.py +12 -0
- snowflake/ml/registry/registry.py +100 -90
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
- snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
- snowflake/ml/model/_client/model/model_method_info.py +0 -19
- snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
- /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
- /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
|
|
54
54
|
_SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklearn.", "").split("_")])
|
55
55
|
|
56
56
|
|
57
|
+
def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
|
58
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
59
|
+
return False and callable(getattr(self._sklearn_object, "fit_predict", None))
|
60
|
+
return check
|
61
|
+
|
62
|
+
|
63
|
+
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
+
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
+
return check
|
67
|
+
|
68
|
+
|
57
69
|
class RandomForestClassifier(BaseTransformer):
|
58
70
|
r"""A random forest classifier
|
59
71
|
For more details on this class, see [sklearn.ensemble.RandomForestClassifier]
|
@@ -290,7 +302,9 @@ class RandomForestClassifier(BaseTransformer):
|
|
290
302
|
self.set_label_cols(label_cols)
|
291
303
|
self.set_passthrough_cols(passthrough_cols)
|
292
304
|
self.set_drop_input_cols(drop_input_cols)
|
293
|
-
self.set_sample_weight_col(sample_weight_col)
|
305
|
+
self.set_sample_weight_col(sample_weight_col)
|
306
|
+
self._use_external_memory_version = False
|
307
|
+
self._batch_size = -1
|
294
308
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
295
309
|
|
296
310
|
self._deps = list(deps)
|
@@ -383,11 +397,6 @@ class RandomForestClassifier(BaseTransformer):
|
|
383
397
|
if isinstance(dataset, DataFrame):
|
384
398
|
session = dataset._session
|
385
399
|
assert session is not None # keep mypy happy
|
386
|
-
# Validate that key package version in user workspace are supported in snowflake conda channel
|
387
|
-
# If customer doesn't have package in conda channel, replace the ones have the closest versions
|
388
|
-
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
389
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
390
|
-
|
391
400
|
# Specify input columns so column pruning will be enforced
|
392
401
|
selected_cols = self._get_active_columns()
|
393
402
|
if len(selected_cols) > 0:
|
@@ -415,7 +424,9 @@ class RandomForestClassifier(BaseTransformer):
|
|
415
424
|
label_cols=self.label_cols,
|
416
425
|
sample_weight_col=self.sample_weight_col,
|
417
426
|
autogenerated=self._autogenerated,
|
418
|
-
subproject=_SUBPROJECT
|
427
|
+
subproject=_SUBPROJECT,
|
428
|
+
use_external_memory_version=self._use_external_memory_version,
|
429
|
+
batch_size=self._batch_size,
|
419
430
|
)
|
420
431
|
self._sklearn_object = model_trainer.train()
|
421
432
|
self._is_fitted = True
|
@@ -686,6 +697,22 @@ class RandomForestClassifier(BaseTransformer):
|
|
686
697
|
# each row containing a list of values.
|
687
698
|
expected_dtype = "ARRAY"
|
688
699
|
|
700
|
+
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
701
|
+
if expected_dtype == "":
|
702
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
|
703
|
+
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
704
|
+
expected_dtype = "ARRAY"
|
705
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
|
706
|
+
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
707
|
+
expected_dtype = "ARRAY"
|
708
|
+
else:
|
709
|
+
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
710
|
+
# We can only infer the output types from the input types if the following two statemetns are true:
|
711
|
+
# 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
|
712
|
+
# 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
|
713
|
+
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
714
|
+
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
715
|
+
|
689
716
|
output_df = self._batch_inference(
|
690
717
|
dataset=dataset,
|
691
718
|
inference_method="transform",
|
@@ -701,8 +728,8 @@ class RandomForestClassifier(BaseTransformer):
|
|
701
728
|
|
702
729
|
return output_df
|
703
730
|
|
704
|
-
@available_if(
|
705
|
-
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
|
731
|
+
@available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
|
732
|
+
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
706
733
|
""" Method not supported for this class.
|
707
734
|
|
708
735
|
|
@@ -715,13 +742,21 @@ class RandomForestClassifier(BaseTransformer):
|
|
715
742
|
Returns:
|
716
743
|
Predicted dataset.
|
717
744
|
"""
|
718
|
-
|
719
|
-
|
720
|
-
|
721
|
-
|
722
|
-
|
723
|
-
|
724
|
-
|
745
|
+
self.fit(dataset)
|
746
|
+
assert self._sklearn_object is not None
|
747
|
+
return self._sklearn_object.labels_
|
748
|
+
|
749
|
+
|
750
|
+
@available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
|
751
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
752
|
+
"""
|
753
|
+
Returns:
|
754
|
+
Transformed dataset.
|
755
|
+
"""
|
756
|
+
self.fit(dataset)
|
757
|
+
assert self._sklearn_object is not None
|
758
|
+
return self._sklearn_object.embedding_
|
759
|
+
|
725
760
|
|
726
761
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
727
762
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
|
|
54
54
|
_SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklearn.", "").split("_")])
|
55
55
|
|
56
56
|
|
57
|
+
def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
|
58
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
59
|
+
return False and callable(getattr(self._sklearn_object, "fit_predict", None))
|
60
|
+
return check
|
61
|
+
|
62
|
+
|
63
|
+
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
+
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
+
return check
|
67
|
+
|
68
|
+
|
57
69
|
class RandomForestRegressor(BaseTransformer):
|
58
70
|
r"""A random forest regressor
|
59
71
|
For more details on this class, see [sklearn.ensemble.RandomForestRegressor]
|
@@ -270,7 +282,9 @@ class RandomForestRegressor(BaseTransformer):
|
|
270
282
|
self.set_label_cols(label_cols)
|
271
283
|
self.set_passthrough_cols(passthrough_cols)
|
272
284
|
self.set_drop_input_cols(drop_input_cols)
|
273
|
-
self.set_sample_weight_col(sample_weight_col)
|
285
|
+
self.set_sample_weight_col(sample_weight_col)
|
286
|
+
self._use_external_memory_version = False
|
287
|
+
self._batch_size = -1
|
274
288
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
275
289
|
|
276
290
|
self._deps = list(deps)
|
@@ -362,11 +376,6 @@ class RandomForestRegressor(BaseTransformer):
|
|
362
376
|
if isinstance(dataset, DataFrame):
|
363
377
|
session = dataset._session
|
364
378
|
assert session is not None # keep mypy happy
|
365
|
-
# Validate that key package version in user workspace are supported in snowflake conda channel
|
366
|
-
# If customer doesn't have package in conda channel, replace the ones have the closest versions
|
367
|
-
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
368
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
369
|
-
|
370
379
|
# Specify input columns so column pruning will be enforced
|
371
380
|
selected_cols = self._get_active_columns()
|
372
381
|
if len(selected_cols) > 0:
|
@@ -394,7 +403,9 @@ class RandomForestRegressor(BaseTransformer):
|
|
394
403
|
label_cols=self.label_cols,
|
395
404
|
sample_weight_col=self.sample_weight_col,
|
396
405
|
autogenerated=self._autogenerated,
|
397
|
-
subproject=_SUBPROJECT
|
406
|
+
subproject=_SUBPROJECT,
|
407
|
+
use_external_memory_version=self._use_external_memory_version,
|
408
|
+
batch_size=self._batch_size,
|
398
409
|
)
|
399
410
|
self._sklearn_object = model_trainer.train()
|
400
411
|
self._is_fitted = True
|
@@ -665,6 +676,22 @@ class RandomForestRegressor(BaseTransformer):
|
|
665
676
|
# each row containing a list of values.
|
666
677
|
expected_dtype = "ARRAY"
|
667
678
|
|
679
|
+
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
680
|
+
if expected_dtype == "":
|
681
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
|
682
|
+
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
683
|
+
expected_dtype = "ARRAY"
|
684
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
|
685
|
+
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
686
|
+
expected_dtype = "ARRAY"
|
687
|
+
else:
|
688
|
+
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
689
|
+
# We can only infer the output types from the input types if the following two statemetns are true:
|
690
|
+
# 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
|
691
|
+
# 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
|
692
|
+
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
693
|
+
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
694
|
+
|
668
695
|
output_df = self._batch_inference(
|
669
696
|
dataset=dataset,
|
670
697
|
inference_method="transform",
|
@@ -680,8 +707,8 @@ class RandomForestRegressor(BaseTransformer):
|
|
680
707
|
|
681
708
|
return output_df
|
682
709
|
|
683
|
-
@available_if(
|
684
|
-
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
|
710
|
+
@available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
|
711
|
+
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
685
712
|
""" Method not supported for this class.
|
686
713
|
|
687
714
|
|
@@ -694,13 +721,21 @@ class RandomForestRegressor(BaseTransformer):
|
|
694
721
|
Returns:
|
695
722
|
Predicted dataset.
|
696
723
|
"""
|
697
|
-
|
698
|
-
|
699
|
-
|
700
|
-
|
701
|
-
|
702
|
-
|
703
|
-
|
724
|
+
self.fit(dataset)
|
725
|
+
assert self._sklearn_object is not None
|
726
|
+
return self._sklearn_object.labels_
|
727
|
+
|
728
|
+
|
729
|
+
@available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
|
730
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
731
|
+
"""
|
732
|
+
Returns:
|
733
|
+
Transformed dataset.
|
734
|
+
"""
|
735
|
+
self.fit(dataset)
|
736
|
+
assert self._sklearn_object is not None
|
737
|
+
return self._sklearn_object.embedding_
|
738
|
+
|
704
739
|
|
705
740
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
706
741
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
|
|
54
54
|
_SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklearn.", "").split("_")])
|
55
55
|
|
56
56
|
|
57
|
+
def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
|
58
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
59
|
+
return False and callable(getattr(self._sklearn_object, "fit_predict", None))
|
60
|
+
return check
|
61
|
+
|
62
|
+
|
63
|
+
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
+
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
+
return check
|
67
|
+
|
68
|
+
|
57
69
|
class StackingRegressor(BaseTransformer):
|
58
70
|
r"""Stack of estimators with a final regressor
|
59
71
|
For more details on this class, see [sklearn.ensemble.StackingRegressor]
|
@@ -180,7 +192,9 @@ class StackingRegressor(BaseTransformer):
|
|
180
192
|
self.set_label_cols(label_cols)
|
181
193
|
self.set_passthrough_cols(passthrough_cols)
|
182
194
|
self.set_drop_input_cols(drop_input_cols)
|
183
|
-
self.set_sample_weight_col(sample_weight_col)
|
195
|
+
self.set_sample_weight_col(sample_weight_col)
|
196
|
+
self._use_external_memory_version = False
|
197
|
+
self._batch_size = -1
|
184
198
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
185
199
|
deps = deps | gather_dependencies(estimators)
|
186
200
|
deps = deps | gather_dependencies(final_estimator)
|
@@ -263,11 +277,6 @@ class StackingRegressor(BaseTransformer):
|
|
263
277
|
if isinstance(dataset, DataFrame):
|
264
278
|
session = dataset._session
|
265
279
|
assert session is not None # keep mypy happy
|
266
|
-
# Validate that key package version in user workspace are supported in snowflake conda channel
|
267
|
-
# If customer doesn't have package in conda channel, replace the ones have the closest versions
|
268
|
-
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
269
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
270
|
-
|
271
280
|
# Specify input columns so column pruning will be enforced
|
272
281
|
selected_cols = self._get_active_columns()
|
273
282
|
if len(selected_cols) > 0:
|
@@ -295,7 +304,9 @@ class StackingRegressor(BaseTransformer):
|
|
295
304
|
label_cols=self.label_cols,
|
296
305
|
sample_weight_col=self.sample_weight_col,
|
297
306
|
autogenerated=self._autogenerated,
|
298
|
-
subproject=_SUBPROJECT
|
307
|
+
subproject=_SUBPROJECT,
|
308
|
+
use_external_memory_version=self._use_external_memory_version,
|
309
|
+
batch_size=self._batch_size,
|
299
310
|
)
|
300
311
|
self._sklearn_object = model_trainer.train()
|
301
312
|
self._is_fitted = True
|
@@ -568,6 +579,22 @@ class StackingRegressor(BaseTransformer):
|
|
568
579
|
# each row containing a list of values.
|
569
580
|
expected_dtype = "ARRAY"
|
570
581
|
|
582
|
+
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
583
|
+
if expected_dtype == "":
|
584
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
|
585
|
+
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
586
|
+
expected_dtype = "ARRAY"
|
587
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
|
588
|
+
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
589
|
+
expected_dtype = "ARRAY"
|
590
|
+
else:
|
591
|
+
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
592
|
+
# We can only infer the output types from the input types if the following two statemetns are true:
|
593
|
+
# 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
|
594
|
+
# 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
|
595
|
+
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
596
|
+
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
597
|
+
|
571
598
|
output_df = self._batch_inference(
|
572
599
|
dataset=dataset,
|
573
600
|
inference_method="transform",
|
@@ -583,8 +610,8 @@ class StackingRegressor(BaseTransformer):
|
|
583
610
|
|
584
611
|
return output_df
|
585
612
|
|
586
|
-
@available_if(
|
587
|
-
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
|
613
|
+
@available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
|
614
|
+
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
588
615
|
""" Method not supported for this class.
|
589
616
|
|
590
617
|
|
@@ -597,13 +624,21 @@ class StackingRegressor(BaseTransformer):
|
|
597
624
|
Returns:
|
598
625
|
Predicted dataset.
|
599
626
|
"""
|
600
|
-
|
601
|
-
|
602
|
-
|
603
|
-
|
604
|
-
|
605
|
-
|
606
|
-
|
627
|
+
self.fit(dataset)
|
628
|
+
assert self._sklearn_object is not None
|
629
|
+
return self._sklearn_object.labels_
|
630
|
+
|
631
|
+
|
632
|
+
@available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
|
633
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
634
|
+
"""
|
635
|
+
Returns:
|
636
|
+
Transformed dataset.
|
637
|
+
"""
|
638
|
+
self.fit(dataset)
|
639
|
+
assert self._sklearn_object is not None
|
640
|
+
return self._sklearn_object.embedding_
|
641
|
+
|
607
642
|
|
608
643
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
609
644
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
|
|
54
54
|
_SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklearn.", "").split("_")])
|
55
55
|
|
56
56
|
|
57
|
+
def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
|
58
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
59
|
+
return False and callable(getattr(self._sklearn_object, "fit_predict", None))
|
60
|
+
return check
|
61
|
+
|
62
|
+
|
63
|
+
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
+
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
+
return check
|
67
|
+
|
68
|
+
|
57
69
|
class VotingClassifier(BaseTransformer):
|
58
70
|
r"""Soft Voting/Majority Rule classifier for unfitted estimators
|
59
71
|
For more details on this class, see [sklearn.ensemble.VotingClassifier]
|
@@ -164,7 +176,9 @@ class VotingClassifier(BaseTransformer):
|
|
164
176
|
self.set_label_cols(label_cols)
|
165
177
|
self.set_passthrough_cols(passthrough_cols)
|
166
178
|
self.set_drop_input_cols(drop_input_cols)
|
167
|
-
self.set_sample_weight_col(sample_weight_col)
|
179
|
+
self.set_sample_weight_col(sample_weight_col)
|
180
|
+
self._use_external_memory_version = False
|
181
|
+
self._batch_size = -1
|
168
182
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
169
183
|
deps = deps | gather_dependencies(estimators)
|
170
184
|
self._deps = list(deps)
|
@@ -245,11 +259,6 @@ class VotingClassifier(BaseTransformer):
|
|
245
259
|
if isinstance(dataset, DataFrame):
|
246
260
|
session = dataset._session
|
247
261
|
assert session is not None # keep mypy happy
|
248
|
-
# Validate that key package version in user workspace are supported in snowflake conda channel
|
249
|
-
# If customer doesn't have package in conda channel, replace the ones have the closest versions
|
250
|
-
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
251
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
252
|
-
|
253
262
|
# Specify input columns so column pruning will be enforced
|
254
263
|
selected_cols = self._get_active_columns()
|
255
264
|
if len(selected_cols) > 0:
|
@@ -277,7 +286,9 @@ class VotingClassifier(BaseTransformer):
|
|
277
286
|
label_cols=self.label_cols,
|
278
287
|
sample_weight_col=self.sample_weight_col,
|
279
288
|
autogenerated=self._autogenerated,
|
280
|
-
subproject=_SUBPROJECT
|
289
|
+
subproject=_SUBPROJECT,
|
290
|
+
use_external_memory_version=self._use_external_memory_version,
|
291
|
+
batch_size=self._batch_size,
|
281
292
|
)
|
282
293
|
self._sklearn_object = model_trainer.train()
|
283
294
|
self._is_fitted = True
|
@@ -550,6 +561,22 @@ class VotingClassifier(BaseTransformer):
|
|
550
561
|
# each row containing a list of values.
|
551
562
|
expected_dtype = "ARRAY"
|
552
563
|
|
564
|
+
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
565
|
+
if expected_dtype == "":
|
566
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
|
567
|
+
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
568
|
+
expected_dtype = "ARRAY"
|
569
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
|
570
|
+
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
571
|
+
expected_dtype = "ARRAY"
|
572
|
+
else:
|
573
|
+
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
574
|
+
# We can only infer the output types from the input types if the following two statemetns are true:
|
575
|
+
# 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
|
576
|
+
# 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
|
577
|
+
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
578
|
+
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
579
|
+
|
553
580
|
output_df = self._batch_inference(
|
554
581
|
dataset=dataset,
|
555
582
|
inference_method="transform",
|
@@ -565,8 +592,8 @@ class VotingClassifier(BaseTransformer):
|
|
565
592
|
|
566
593
|
return output_df
|
567
594
|
|
568
|
-
@available_if(
|
569
|
-
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
|
595
|
+
@available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
|
596
|
+
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
570
597
|
""" Method not supported for this class.
|
571
598
|
|
572
599
|
|
@@ -579,13 +606,21 @@ class VotingClassifier(BaseTransformer):
|
|
579
606
|
Returns:
|
580
607
|
Predicted dataset.
|
581
608
|
"""
|
582
|
-
|
583
|
-
|
584
|
-
|
585
|
-
|
586
|
-
|
587
|
-
|
588
|
-
|
609
|
+
self.fit(dataset)
|
610
|
+
assert self._sklearn_object is not None
|
611
|
+
return self._sklearn_object.labels_
|
612
|
+
|
613
|
+
|
614
|
+
@available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
|
615
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
616
|
+
"""
|
617
|
+
Returns:
|
618
|
+
Transformed dataset.
|
619
|
+
"""
|
620
|
+
self.fit(dataset)
|
621
|
+
assert self._sklearn_object is not None
|
622
|
+
return self._sklearn_object.embedding_
|
623
|
+
|
589
624
|
|
590
625
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
591
626
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
|
|
54
54
|
_SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklearn.", "").split("_")])
|
55
55
|
|
56
56
|
|
57
|
+
def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
|
58
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
59
|
+
return False and callable(getattr(self._sklearn_object, "fit_predict", None))
|
60
|
+
return check
|
61
|
+
|
62
|
+
|
63
|
+
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
+
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
+
return check
|
67
|
+
|
68
|
+
|
57
69
|
class VotingRegressor(BaseTransformer):
|
58
70
|
r"""Prediction voting regressor for unfitted estimators
|
59
71
|
For more details on this class, see [sklearn.ensemble.VotingRegressor]
|
@@ -148,7 +160,9 @@ class VotingRegressor(BaseTransformer):
|
|
148
160
|
self.set_label_cols(label_cols)
|
149
161
|
self.set_passthrough_cols(passthrough_cols)
|
150
162
|
self.set_drop_input_cols(drop_input_cols)
|
151
|
-
self.set_sample_weight_col(sample_weight_col)
|
163
|
+
self.set_sample_weight_col(sample_weight_col)
|
164
|
+
self._use_external_memory_version = False
|
165
|
+
self._batch_size = -1
|
152
166
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
153
167
|
deps = deps | gather_dependencies(estimators)
|
154
168
|
self._deps = list(deps)
|
@@ -227,11 +241,6 @@ class VotingRegressor(BaseTransformer):
|
|
227
241
|
if isinstance(dataset, DataFrame):
|
228
242
|
session = dataset._session
|
229
243
|
assert session is not None # keep mypy happy
|
230
|
-
# Validate that key package version in user workspace are supported in snowflake conda channel
|
231
|
-
# If customer doesn't have package in conda channel, replace the ones have the closest versions
|
232
|
-
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
233
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
234
|
-
|
235
244
|
# Specify input columns so column pruning will be enforced
|
236
245
|
selected_cols = self._get_active_columns()
|
237
246
|
if len(selected_cols) > 0:
|
@@ -259,7 +268,9 @@ class VotingRegressor(BaseTransformer):
|
|
259
268
|
label_cols=self.label_cols,
|
260
269
|
sample_weight_col=self.sample_weight_col,
|
261
270
|
autogenerated=self._autogenerated,
|
262
|
-
subproject=_SUBPROJECT
|
271
|
+
subproject=_SUBPROJECT,
|
272
|
+
use_external_memory_version=self._use_external_memory_version,
|
273
|
+
batch_size=self._batch_size,
|
263
274
|
)
|
264
275
|
self._sklearn_object = model_trainer.train()
|
265
276
|
self._is_fitted = True
|
@@ -532,6 +543,22 @@ class VotingRegressor(BaseTransformer):
|
|
532
543
|
# each row containing a list of values.
|
533
544
|
expected_dtype = "ARRAY"
|
534
545
|
|
546
|
+
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
547
|
+
if expected_dtype == "":
|
548
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
|
549
|
+
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
550
|
+
expected_dtype = "ARRAY"
|
551
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
|
552
|
+
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
553
|
+
expected_dtype = "ARRAY"
|
554
|
+
else:
|
555
|
+
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
556
|
+
# We can only infer the output types from the input types if the following two statemetns are true:
|
557
|
+
# 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
|
558
|
+
# 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
|
559
|
+
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
560
|
+
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
561
|
+
|
535
562
|
output_df = self._batch_inference(
|
536
563
|
dataset=dataset,
|
537
564
|
inference_method="transform",
|
@@ -547,8 +574,8 @@ class VotingRegressor(BaseTransformer):
|
|
547
574
|
|
548
575
|
return output_df
|
549
576
|
|
550
|
-
@available_if(
|
551
|
-
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
|
577
|
+
@available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
|
578
|
+
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
552
579
|
""" Method not supported for this class.
|
553
580
|
|
554
581
|
|
@@ -561,13 +588,21 @@ class VotingRegressor(BaseTransformer):
|
|
561
588
|
Returns:
|
562
589
|
Predicted dataset.
|
563
590
|
"""
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
|
591
|
+
self.fit(dataset)
|
592
|
+
assert self._sklearn_object is not None
|
593
|
+
return self._sklearn_object.labels_
|
594
|
+
|
595
|
+
|
596
|
+
@available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
|
597
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
598
|
+
"""
|
599
|
+
Returns:
|
600
|
+
Transformed dataset.
|
601
|
+
"""
|
602
|
+
self.fit(dataset)
|
603
|
+
assert self._sklearn_object is not None
|
604
|
+
return self._sklearn_object.embedding_
|
605
|
+
|
571
606
|
|
572
607
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
573
608
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|