snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
- snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
- snowflake/ml/_internal/env_utils.py +31 -52
- snowflake/ml/_internal/file_utils.py +17 -0
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/_internal/utils/query_result_checker.py +8 -5
- snowflake/ml/_internal/utils/snowflake_env.py +95 -0
- snowflake/ml/fileset/parquet_parser.py +31 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/model_impl.py +172 -13
- snowflake/ml/model/_client/model/model_version_impl.py +96 -52
- snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
- snowflake/ml/model/_client/ops/model_ops.py +155 -9
- snowflake/ml/model/_client/sql/model.py +55 -10
- snowflake/ml/model/_client/sql/model_version.py +72 -61
- snowflake/ml/model/_client/sql/stage.py +10 -4
- snowflake/ml/model/_client/sql/tag.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/_signatures/core.py +20 -17
- snowflake/ml/model/custom_model.py +30 -27
- snowflake/ml/model/model_signature.py +16 -17
- snowflake/ml/model/type_hints.py +3 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
- snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
- snowflake/ml/modeling/_internal/model_specifications.py +3 -10
- snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
- snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
- snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
- snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
- snowflake/ml/modeling/cluster/birch.py +51 -16
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
- snowflake/ml/modeling/cluster/dbscan.py +51 -16
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
- snowflake/ml/modeling/cluster/k_means.py +51 -16
- snowflake/ml/modeling/cluster/mean_shift.py +51 -16
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
- snowflake/ml/modeling/cluster/optics.py +51 -16
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
- snowflake/ml/modeling/compose/column_transformer.py +51 -16
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
- snowflake/ml/modeling/covariance/oas.py +51 -16
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/pca.py +51 -16
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
- snowflake/ml/modeling/impute/knn_imputer.py +51 -16
- snowflake/ml/modeling/impute/missing_indicator.py +51 -16
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/lars.py +51 -16
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/perceptron.py +51 -16
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ridge.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
- snowflake/ml/modeling/manifold/isomap.py +51 -16
- snowflake/ml/modeling/manifold/mds.py +51 -16
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
- snowflake/ml/modeling/manifold/tsne.py +51 -16
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
- snowflake/ml/modeling/svm/linear_svc.py +51 -16
- snowflake/ml/modeling/svm/linear_svr.py +51 -16
- snowflake/ml/modeling/svm/nu_svc.py +51 -16
- snowflake/ml/modeling/svm/nu_svr.py +51 -16
- snowflake/ml/modeling/svm/svc.py +51 -16
- snowflake/ml/modeling/svm/svr.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
- snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
- snowflake/ml/registry/__init__.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +163 -0
- snowflake/ml/registry/model_registry.py +12 -0
- snowflake/ml/registry/registry.py +100 -90
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
- snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
- snowflake/ml/model/_client/model/model_method_info.py +0 -19
- snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
- /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
- /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
snowflake/ml/modeling/ensemble/ada_boost_regressor.py

```diff
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class AdaBoostRegressor(BaseTransformer):
     r"""An AdaBoost regressor
     For more details on this class, see [sklearn.ensemble.AdaBoostRegressor]
@@ -166,7 +178,9 @@ class AdaBoostRegressor(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
         deps = deps | gather_dependencies(estimator)
         deps = deps | gather_dependencies(base_estimator)
@@ -249,11 +263,6 @@ class AdaBoostRegressor(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -281,7 +290,9 @@ class AdaBoostRegressor(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -552,6 +563,22 @@ class AdaBoostRegressor(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -567,8 +594,8 @@ class AdaBoostRegressor(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -581,13 +608,21 @@ class AdaBoostRegressor(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
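The gating change above is worth unpacking. `sklearn.utils.metaestimators.available_if` hides a method entirely (attribute access raises `AttributeError`) whenever its check callable returns a falsy value, and the generated `check` bodies here are `return False and ...`, which short-circuits to `False` every time, so `fit_predict` and `fit_transform` stay disabled on these estimators. A minimal, self-contained sketch of that behavior (the class and names below are illustrative, not from the wheel):

```python
from sklearn.utils.metaestimators import available_if


def _never_enabled(self) -> bool:
    # Mirrors the generated guard: `False and ...` short-circuits,
    # so the wrapped method is always reported as unavailable.
    return False and callable(getattr(self, "_impl", None))


class Demo:
    @available_if(_never_enabled)
    def fit_predict(self, X):
        return X


print(hasattr(Demo(), "fit_predict"))  # False: available_if hides the method
# Demo().fit_predict([1, 2])           # would raise AttributeError
```

snowflake/ml/modeling/ensemble/bagging_classifier.py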
```diff
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class BaggingClassifier(BaseTransformer):
     r"""A Bagging classifier
     For more details on this class, see [sklearn.ensemble.BaggingClassifier]
@@ -195,7 +207,9 @@ class BaggingClassifier(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
         deps = deps | gather_dependencies(estimator)
         deps = deps | gather_dependencies(base_estimator)
@@ -284,11 +298,6 @@ class BaggingClassifier(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -316,7 +325,9 @@ class BaggingClassifier(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -587,6 +598,22 @@ class BaggingClassifier(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -602,8 +629,8 @@ class BaggingClassifier(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -616,13 +643,21 @@ class BaggingClassifier(BaseTransformer):
         Returns:
            Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
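The `use_external_memory_version` / `batch_size` plumbing seen in these constructors pairs with the new `snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py` (+444 lines in the file list above). The diff does not show that trainer's internals, but XGBoost's documented external-memory protocol, which such a trainer would presumably build on, looks like the following hedged sketch; `BatchIter`, the batch layout, and the cache path are invented for illustration:

```python
import xgboost as xgb


class BatchIter(xgb.DataIter):
    """Feeds fixed-size (X, y) batches to XGBoost so the full training
    set never has to be materialized in memory at once."""

    def __init__(self, batches):
        self._batches = batches  # e.g. a list of (features, labels) pairs
        self._i = 0
        super().__init__(cache_prefix="./xgb_cache")  # on-disk cache location

    def next(self, input_data) -> int:
        if self._i == len(self._batches):
            return 0  # 0 tells XGBoost the iterator is exhausted
        X, y = self._batches[self._i]
        input_data(data=X, label=y)  # hand one batch to XGBoost
        self._i += 1
        return 1

    def reset(self) -> None:
        self._i = 0


# dtrain = xgb.DMatrix(BatchIter(batches))           # batches stream via the cache
# booster = xgb.train({"tree_method": "hist"}, dtrain)
```

Note also the `self._batch_size = -1` default: a negative sentinel that plausibly means "no batching configured", consistent with `use_external_memory_version = False`.

snowflake/ml/modeling/ensemble/bagging_regressor.py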
```diff
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class BaggingRegressor(BaseTransformer):
     r"""A Bagging regressor
     For more details on this class, see [sklearn.ensemble.BaggingRegressor]
@@ -195,7 +207,9 @@ class BaggingRegressor(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
         deps = deps | gather_dependencies(estimator)
         deps = deps | gather_dependencies(base_estimator)
@@ -284,11 +298,6 @@ class BaggingRegressor(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -316,7 +325,9 @@ class BaggingRegressor(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -587,6 +598,22 @@ class BaggingRegressor(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -602,8 +629,8 @@ class BaggingRegressor(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -616,13 +643,21 @@ class BaggingRegressor(BaseTransformer):
         Returns:
            Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
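The repeated `expected_dtype` hunk encodes a small decision tree for picking the Snowflake output column type when the factory could not determine one. Stripped of the Snowpark helpers (`_infer_signature`, `convert_sp_to_sf_type`), the fallback reduces to roughly this paraphrase, with plain strings standing in for Snowpark types (the function name and arguments below are assumptions for illustration):

```python
def infer_expected_dtype(sk_obj, output_cols, input_col_types):
    """Paraphrase of the generated fallback: prefer ARRAY whenever the
    estimator's natural output width disagrees with the declared columns."""
    if hasattr(sk_obj, "n_clusters") and sk_obj.n_clusters != len(output_cols):
        return "ARRAY"  # clustering: one list-valued column per row
    if hasattr(sk_obj, "n_components") and sk_obj.n_components != len(output_cols):
        return "ARRAY"  # decomposition: same reasoning
    # Reuse the input type only when all inputs share one type and the
    # transform maps columns one-to-one; otherwise keep the default and
    # let the caller fall back to a variant type.
    if (
        all(t == input_col_types[0] for t in input_col_types)
        and len(input_col_types) == len(output_cols)
    ):
        return input_col_types[0]
    return ""
```

snowflake/ml/modeling/ensemble/extra_trees_classifier.py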
```diff
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class ExtraTreesClassifier(BaseTransformer):
     r"""An extra-trees classifier
     For more details on this class, see [sklearn.ensemble.ExtraTreesClassifier]
@@ -294,7 +306,9 @@ class ExtraTreesClassifier(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -387,11 +401,6 @@ class ExtraTreesClassifier(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -419,7 +428,9 @@ class ExtraTreesClassifier(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -690,6 +701,22 @@ class ExtraTreesClassifier(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -705,8 +732,8 @@ class ExtraTreesClassifier(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -719,13 +746,21 @@ class ExtraTreesClassifier(BaseTransformer):
         Returns:
            Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
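The new `fit_predict` and `fit_transform` bodies (`self.fit(dataset)` followed by returning `labels_` or `embedding_`) mirror the standard scikit-learn conventions: for clustering estimators `fit_predict(X)` is equivalent to `fit(X)` followed by reading `labels_`, and for manifold learners `fit_transform(X)` yields `embedding_`. A quick local check of the clustering case (plain scikit-learn, independent of this wheel):

```python
import numpy as np
from sklearn.cluster import KMeans

X = np.random.RandomState(0).rand(20, 2)

labels_a = KMeans(n_clusters=2, n_init=10, random_state=0).fit(X).labels_
labels_b = KMeans(n_clusters=2, n_init=10, random_state=0).fit_predict(X)

assert (labels_a == labels_b).all()  # fit_predict == fit() + labels_
```

snowflake/ml/modeling/ensemble/extra_trees_regressor.py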
```diff
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class ExtraTreesRegressor(BaseTransformer):
     r"""An extra-trees regressor
     For more details on this class, see [sklearn.ensemble.ExtraTreesRegressor]
@@ -274,7 +286,9 @@ class ExtraTreesRegressor(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -366,11 +380,6 @@ class ExtraTreesRegressor(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -398,7 +407,9 @@ class ExtraTreesRegressor(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -669,6 +680,22 @@ class ExtraTreesRegressor(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -684,8 +711,8 @@ class ExtraTreesRegressor(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -698,13 +725,21 @@ class ExtraTreesRegressor(BaseTransformer):
         Returns:
            Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
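One typing detail recurs in every generated guard above: the inner `check` returns `TypeGuard[Callable[..., object]]` (PEP 647), which lets a static type checker narrow the guarded attribute's type whenever the check returns true. The canonical shape of a `TypeGuard`, independent of this wheel:

```python
from typing import Any, List
from typing_extensions import TypeGuard  # typing.TypeGuard on Python >= 3.10


def is_str_list(val: List[Any]) -> TypeGuard[List[str]]:
    # Returning True tells the type checker val is List[str] in that branch.
    return all(isinstance(x, str) for x in val)


items: List[Any] = ["a", "b"]
if is_str_list(items):
    print(", ".join(items))  # checker now treats items as List[str]
```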