snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
- snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
- snowflake/ml/_internal/env_utils.py +31 -52
- snowflake/ml/_internal/file_utils.py +17 -0
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/_internal/utils/query_result_checker.py +8 -5
- snowflake/ml/_internal/utils/snowflake_env.py +95 -0
- snowflake/ml/fileset/parquet_parser.py +31 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/model_impl.py +172 -13
- snowflake/ml/model/_client/model/model_version_impl.py +96 -52
- snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
- snowflake/ml/model/_client/ops/model_ops.py +155 -9
- snowflake/ml/model/_client/sql/model.py +55 -10
- snowflake/ml/model/_client/sql/model_version.py +72 -61
- snowflake/ml/model/_client/sql/stage.py +10 -4
- snowflake/ml/model/_client/sql/tag.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/_signatures/core.py +20 -17
- snowflake/ml/model/custom_model.py +30 -27
- snowflake/ml/model/model_signature.py +16 -17
- snowflake/ml/model/type_hints.py +3 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
- snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
- snowflake/ml/modeling/_internal/model_specifications.py +3 -10
- snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
- snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
- snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
- snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
- snowflake/ml/modeling/cluster/birch.py +51 -16
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
- snowflake/ml/modeling/cluster/dbscan.py +51 -16
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
- snowflake/ml/modeling/cluster/k_means.py +51 -16
- snowflake/ml/modeling/cluster/mean_shift.py +51 -16
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
- snowflake/ml/modeling/cluster/optics.py +51 -16
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
- snowflake/ml/modeling/compose/column_transformer.py +51 -16
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
- snowflake/ml/modeling/covariance/oas.py +51 -16
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/pca.py +51 -16
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
- snowflake/ml/modeling/impute/knn_imputer.py +51 -16
- snowflake/ml/modeling/impute/missing_indicator.py +51 -16
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/lars.py +51 -16
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/perceptron.py +51 -16
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ridge.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
- snowflake/ml/modeling/manifold/isomap.py +51 -16
- snowflake/ml/modeling/manifold/mds.py +51 -16
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
- snowflake/ml/modeling/manifold/tsne.py +51 -16
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
- snowflake/ml/modeling/svm/linear_svc.py +51 -16
- snowflake/ml/modeling/svm/linear_svr.py +51 -16
- snowflake/ml/modeling/svm/nu_svc.py +51 -16
- snowflake/ml/modeling/svm/nu_svr.py +51 -16
- snowflake/ml/modeling/svm/svc.py +51 -16
- snowflake/ml/modeling/svm/svr.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
- snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
- snowflake/ml/registry/__init__.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +163 -0
- snowflake/ml/registry/model_registry.py +12 -0
- snowflake/ml/registry/registry.py +100 -90
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
- snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
- snowflake/ml/model/_client/model/model_method_info.py +0 -19
- snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
- /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
- /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
|
|
54
54
|
_SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklearn.", "").split("_")])
|
55
55
|
|
56
56
|
|
57
|
+
def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
|
58
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
59
|
+
return False and callable(getattr(self._sklearn_object, "fit_predict", None))
|
60
|
+
return check
|
61
|
+
|
62
|
+
|
63
|
+
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
+
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
+
return check
|
67
|
+
|
68
|
+
|
57
69
|
class GradientBoostingClassifier(BaseTransformer):
|
58
70
|
r"""Gradient Boosting for classification
|
59
71
|
For more details on this class, see [sklearn.ensemble.GradientBoostingClassifier]
|
@@ -304,7 +316,9 @@ class GradientBoostingClassifier(BaseTransformer):
|
|
304
316
|
self.set_label_cols(label_cols)
|
305
317
|
self.set_passthrough_cols(passthrough_cols)
|
306
318
|
self.set_drop_input_cols(drop_input_cols)
|
307
|
-
self.set_sample_weight_col(sample_weight_col)
|
319
|
+
self.set_sample_weight_col(sample_weight_col)
|
320
|
+
self._use_external_memory_version = False
|
321
|
+
self._batch_size = -1
|
308
322
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
309
323
|
|
310
324
|
self._deps = list(deps)
|
@@ -399,11 +413,6 @@ class GradientBoostingClassifier(BaseTransformer):
|
|
399
413
|
if isinstance(dataset, DataFrame):
|
400
414
|
session = dataset._session
|
401
415
|
assert session is not None # keep mypy happy
|
402
|
-
# Validate that key package version in user workspace are supported in snowflake conda channel
|
403
|
-
# If customer doesn't have package in conda channel, replace the ones have the closest versions
|
404
|
-
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
405
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
406
|
-
|
407
416
|
# Specify input columns so column pruning will be enforced
|
408
417
|
selected_cols = self._get_active_columns()
|
409
418
|
if len(selected_cols) > 0:
|
@@ -431,7 +440,9 @@ class GradientBoostingClassifier(BaseTransformer):
|
|
431
440
|
label_cols=self.label_cols,
|
432
441
|
sample_weight_col=self.sample_weight_col,
|
433
442
|
autogenerated=self._autogenerated,
|
434
|
-
subproject=_SUBPROJECT
|
443
|
+
subproject=_SUBPROJECT,
|
444
|
+
use_external_memory_version=self._use_external_memory_version,
|
445
|
+
batch_size=self._batch_size,
|
435
446
|
)
|
436
447
|
self._sklearn_object = model_trainer.train()
|
437
448
|
self._is_fitted = True
|
@@ -702,6 +713,22 @@ class GradientBoostingClassifier(BaseTransformer):
|
|
702
713
|
# each row containing a list of values.
|
703
714
|
expected_dtype = "ARRAY"
|
704
715
|
|
716
|
+
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
717
|
+
if expected_dtype == "":
|
718
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
|
719
|
+
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
720
|
+
expected_dtype = "ARRAY"
|
721
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
|
722
|
+
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
723
|
+
expected_dtype = "ARRAY"
|
724
|
+
else:
|
725
|
+
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
726
|
+
# We can only infer the output types from the input types if the following two statemetns are true:
|
727
|
+
# 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
|
728
|
+
# 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
|
729
|
+
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
730
|
+
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
731
|
+
|
705
732
|
output_df = self._batch_inference(
|
706
733
|
dataset=dataset,
|
707
734
|
inference_method="transform",
|
@@ -717,8 +744,8 @@ class GradientBoostingClassifier(BaseTransformer):
|
|
717
744
|
|
718
745
|
return output_df
|
719
746
|
|
720
|
-
@available_if(
|
721
|
-
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
|
747
|
+
@available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
|
748
|
+
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
722
749
|
""" Method not supported for this class.
|
723
750
|
|
724
751
|
|
@@ -731,13 +758,21 @@ class GradientBoostingClassifier(BaseTransformer):
|
|
731
758
|
Returns:
|
732
759
|
Predicted dataset.
|
733
760
|
"""
|
734
|
-
|
735
|
-
|
736
|
-
|
737
|
-
|
738
|
-
|
739
|
-
|
740
|
-
|
761
|
+
self.fit(dataset)
|
762
|
+
assert self._sklearn_object is not None
|
763
|
+
return self._sklearn_object.labels_
|
764
|
+
|
765
|
+
|
766
|
+
@available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
|
767
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
768
|
+
"""
|
769
|
+
Returns:
|
770
|
+
Transformed dataset.
|
771
|
+
"""
|
772
|
+
self.fit(dataset)
|
773
|
+
assert self._sklearn_object is not None
|
774
|
+
return self._sklearn_object.embedding_
|
775
|
+
|
741
776
|
|
742
777
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
743
778
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
|
|
54
54
|
_SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklearn.", "").split("_")])
|
55
55
|
|
56
56
|
|
57
|
+
def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
|
58
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
59
|
+
return False and callable(getattr(self._sklearn_object, "fit_predict", None))
|
60
|
+
return check
|
61
|
+
|
62
|
+
|
63
|
+
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
+
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
+
return check
|
67
|
+
|
68
|
+
|
57
69
|
class GradientBoostingRegressor(BaseTransformer):
|
58
70
|
r"""Gradient Boosting for regression
|
59
71
|
For more details on this class, see [sklearn.ensemble.GradientBoostingRegressor]
|
@@ -312,7 +324,9 @@ class GradientBoostingRegressor(BaseTransformer):
|
|
312
324
|
self.set_label_cols(label_cols)
|
313
325
|
self.set_passthrough_cols(passthrough_cols)
|
314
326
|
self.set_drop_input_cols(drop_input_cols)
|
315
|
-
self.set_sample_weight_col(sample_weight_col)
|
327
|
+
self.set_sample_weight_col(sample_weight_col)
|
328
|
+
self._use_external_memory_version = False
|
329
|
+
self._batch_size = -1
|
316
330
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
317
331
|
|
318
332
|
self._deps = list(deps)
|
@@ -408,11 +422,6 @@ class GradientBoostingRegressor(BaseTransformer):
|
|
408
422
|
if isinstance(dataset, DataFrame):
|
409
423
|
session = dataset._session
|
410
424
|
assert session is not None # keep mypy happy
|
411
|
-
# Validate that key package version in user workspace are supported in snowflake conda channel
|
412
|
-
# If customer doesn't have package in conda channel, replace the ones have the closest versions
|
413
|
-
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
414
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
415
|
-
|
416
425
|
# Specify input columns so column pruning will be enforced
|
417
426
|
selected_cols = self._get_active_columns()
|
418
427
|
if len(selected_cols) > 0:
|
@@ -440,7 +449,9 @@ class GradientBoostingRegressor(BaseTransformer):
|
|
440
449
|
label_cols=self.label_cols,
|
441
450
|
sample_weight_col=self.sample_weight_col,
|
442
451
|
autogenerated=self._autogenerated,
|
443
|
-
subproject=_SUBPROJECT
|
452
|
+
subproject=_SUBPROJECT,
|
453
|
+
use_external_memory_version=self._use_external_memory_version,
|
454
|
+
batch_size=self._batch_size,
|
444
455
|
)
|
445
456
|
self._sklearn_object = model_trainer.train()
|
446
457
|
self._is_fitted = True
|
@@ -711,6 +722,22 @@ class GradientBoostingRegressor(BaseTransformer):
|
|
711
722
|
# each row containing a list of values.
|
712
723
|
expected_dtype = "ARRAY"
|
713
724
|
|
725
|
+
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
726
|
+
if expected_dtype == "":
|
727
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
|
728
|
+
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
729
|
+
expected_dtype = "ARRAY"
|
730
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
|
731
|
+
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
732
|
+
expected_dtype = "ARRAY"
|
733
|
+
else:
|
734
|
+
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
735
|
+
# We can only infer the output types from the input types if the following two statemetns are true:
|
736
|
+
# 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
|
737
|
+
# 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
|
738
|
+
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
739
|
+
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
740
|
+
|
714
741
|
output_df = self._batch_inference(
|
715
742
|
dataset=dataset,
|
716
743
|
inference_method="transform",
|
@@ -726,8 +753,8 @@ class GradientBoostingRegressor(BaseTransformer):
|
|
726
753
|
|
727
754
|
return output_df
|
728
755
|
|
729
|
-
@available_if(
|
730
|
-
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
|
756
|
+
@available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
|
757
|
+
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
731
758
|
""" Method not supported for this class.
|
732
759
|
|
733
760
|
|
@@ -740,13 +767,21 @@ class GradientBoostingRegressor(BaseTransformer):
|
|
740
767
|
Returns:
|
741
768
|
Predicted dataset.
|
742
769
|
"""
|
743
|
-
|
744
|
-
|
745
|
-
|
746
|
-
|
747
|
-
|
748
|
-
|
749
|
-
|
770
|
+
self.fit(dataset)
|
771
|
+
assert self._sklearn_object is not None
|
772
|
+
return self._sklearn_object.labels_
|
773
|
+
|
774
|
+
|
775
|
+
@available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
|
776
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
777
|
+
"""
|
778
|
+
Returns:
|
779
|
+
Transformed dataset.
|
780
|
+
"""
|
781
|
+
self.fit(dataset)
|
782
|
+
assert self._sklearn_object is not None
|
783
|
+
return self._sklearn_object.embedding_
|
784
|
+
|
750
785
|
|
751
786
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
752
787
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
|
|
54
54
|
_SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklearn.", "").split("_")])
|
55
55
|
|
56
56
|
|
57
|
+
def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
|
58
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
59
|
+
return False and callable(getattr(self._sklearn_object, "fit_predict", None))
|
60
|
+
return check
|
61
|
+
|
62
|
+
|
63
|
+
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
+
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
+
return check
|
67
|
+
|
68
|
+
|
57
69
|
class HistGradientBoostingClassifier(BaseTransformer):
|
58
70
|
r"""Histogram-based Gradient Boosting Classification Tree
|
59
71
|
For more details on this class, see [sklearn.ensemble.HistGradientBoostingClassifier]
|
@@ -285,7 +297,9 @@ class HistGradientBoostingClassifier(BaseTransformer):
|
|
285
297
|
self.set_label_cols(label_cols)
|
286
298
|
self.set_passthrough_cols(passthrough_cols)
|
287
299
|
self.set_drop_input_cols(drop_input_cols)
|
288
|
-
self.set_sample_weight_col(sample_weight_col)
|
300
|
+
self.set_sample_weight_col(sample_weight_col)
|
301
|
+
self._use_external_memory_version = False
|
302
|
+
self._batch_size = -1
|
289
303
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
290
304
|
|
291
305
|
self._deps = list(deps)
|
@@ -380,11 +394,6 @@ class HistGradientBoostingClassifier(BaseTransformer):
|
|
380
394
|
if isinstance(dataset, DataFrame):
|
381
395
|
session = dataset._session
|
382
396
|
assert session is not None # keep mypy happy
|
383
|
-
# Validate that key package version in user workspace are supported in snowflake conda channel
|
384
|
-
# If customer doesn't have package in conda channel, replace the ones have the closest versions
|
385
|
-
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
386
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
387
|
-
|
388
397
|
# Specify input columns so column pruning will be enforced
|
389
398
|
selected_cols = self._get_active_columns()
|
390
399
|
if len(selected_cols) > 0:
|
@@ -412,7 +421,9 @@ class HistGradientBoostingClassifier(BaseTransformer):
|
|
412
421
|
label_cols=self.label_cols,
|
413
422
|
sample_weight_col=self.sample_weight_col,
|
414
423
|
autogenerated=self._autogenerated,
|
415
|
-
subproject=_SUBPROJECT
|
424
|
+
subproject=_SUBPROJECT,
|
425
|
+
use_external_memory_version=self._use_external_memory_version,
|
426
|
+
batch_size=self._batch_size,
|
416
427
|
)
|
417
428
|
self._sklearn_object = model_trainer.train()
|
418
429
|
self._is_fitted = True
|
@@ -683,6 +694,22 @@ class HistGradientBoostingClassifier(BaseTransformer):
|
|
683
694
|
# each row containing a list of values.
|
684
695
|
expected_dtype = "ARRAY"
|
685
696
|
|
697
|
+
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
698
|
+
if expected_dtype == "":
|
699
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
|
700
|
+
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
701
|
+
expected_dtype = "ARRAY"
|
702
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
|
703
|
+
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
704
|
+
expected_dtype = "ARRAY"
|
705
|
+
else:
|
706
|
+
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
707
|
+
# We can only infer the output types from the input types if the following two statemetns are true:
|
708
|
+
# 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
|
709
|
+
# 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
|
710
|
+
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
711
|
+
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
712
|
+
|
686
713
|
output_df = self._batch_inference(
|
687
714
|
dataset=dataset,
|
688
715
|
inference_method="transform",
|
@@ -698,8 +725,8 @@ class HistGradientBoostingClassifier(BaseTransformer):
|
|
698
725
|
|
699
726
|
return output_df
|
700
727
|
|
701
|
-
@available_if(
|
702
|
-
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
|
728
|
+
@available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
|
729
|
+
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
703
730
|
""" Method not supported for this class.
|
704
731
|
|
705
732
|
|
@@ -712,13 +739,21 @@ class HistGradientBoostingClassifier(BaseTransformer):
|
|
712
739
|
Returns:
|
713
740
|
Predicted dataset.
|
714
741
|
"""
|
715
|
-
|
716
|
-
|
717
|
-
|
718
|
-
|
719
|
-
|
720
|
-
|
721
|
-
|
742
|
+
self.fit(dataset)
|
743
|
+
assert self._sklearn_object is not None
|
744
|
+
return self._sklearn_object.labels_
|
745
|
+
|
746
|
+
|
747
|
+
@available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
|
748
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
749
|
+
"""
|
750
|
+
Returns:
|
751
|
+
Transformed dataset.
|
752
|
+
"""
|
753
|
+
self.fit(dataset)
|
754
|
+
assert self._sklearn_object is not None
|
755
|
+
return self._sklearn_object.embedding_
|
756
|
+
|
722
757
|
|
723
758
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
724
759
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
|
|
54
54
|
_SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklearn.", "").split("_")])
|
55
55
|
|
56
56
|
|
57
|
+
def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
|
58
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
59
|
+
return False and callable(getattr(self._sklearn_object, "fit_predict", None))
|
60
|
+
return check
|
61
|
+
|
62
|
+
|
63
|
+
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
+
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
+
return check
|
67
|
+
|
68
|
+
|
57
69
|
class HistGradientBoostingRegressor(BaseTransformer):
|
58
70
|
r"""Histogram-based Gradient Boosting Regression Tree
|
59
71
|
For more details on this class, see [sklearn.ensemble.HistGradientBoostingRegressor]
|
@@ -276,7 +288,9 @@ class HistGradientBoostingRegressor(BaseTransformer):
|
|
276
288
|
self.set_label_cols(label_cols)
|
277
289
|
self.set_passthrough_cols(passthrough_cols)
|
278
290
|
self.set_drop_input_cols(drop_input_cols)
|
279
|
-
self.set_sample_weight_col(sample_weight_col)
|
291
|
+
self.set_sample_weight_col(sample_weight_col)
|
292
|
+
self._use_external_memory_version = False
|
293
|
+
self._batch_size = -1
|
280
294
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
281
295
|
|
282
296
|
self._deps = list(deps)
|
@@ -371,11 +385,6 @@ class HistGradientBoostingRegressor(BaseTransformer):
|
|
371
385
|
if isinstance(dataset, DataFrame):
|
372
386
|
session = dataset._session
|
373
387
|
assert session is not None # keep mypy happy
|
374
|
-
# Validate that key package version in user workspace are supported in snowflake conda channel
|
375
|
-
# If customer doesn't have package in conda channel, replace the ones have the closest versions
|
376
|
-
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
377
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
378
|
-
|
379
388
|
# Specify input columns so column pruning will be enforced
|
380
389
|
selected_cols = self._get_active_columns()
|
381
390
|
if len(selected_cols) > 0:
|
@@ -403,7 +412,9 @@ class HistGradientBoostingRegressor(BaseTransformer):
|
|
403
412
|
label_cols=self.label_cols,
|
404
413
|
sample_weight_col=self.sample_weight_col,
|
405
414
|
autogenerated=self._autogenerated,
|
406
|
-
subproject=_SUBPROJECT
|
415
|
+
subproject=_SUBPROJECT,
|
416
|
+
use_external_memory_version=self._use_external_memory_version,
|
417
|
+
batch_size=self._batch_size,
|
407
418
|
)
|
408
419
|
self._sklearn_object = model_trainer.train()
|
409
420
|
self._is_fitted = True
|
@@ -674,6 +685,22 @@ class HistGradientBoostingRegressor(BaseTransformer):
|
|
674
685
|
# each row containing a list of values.
|
675
686
|
expected_dtype = "ARRAY"
|
676
687
|
|
688
|
+
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
689
|
+
if expected_dtype == "":
|
690
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
|
691
|
+
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
692
|
+
expected_dtype = "ARRAY"
|
693
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
|
694
|
+
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
695
|
+
expected_dtype = "ARRAY"
|
696
|
+
else:
|
697
|
+
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
698
|
+
# We can only infer the output types from the input types if the following two statements are true:
|
699
|
+
# 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
|
700
|
+
# 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
|
701
|
+
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
702
|
+
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
703
|
+
|
677
704
|
output_df = self._batch_inference(
|
678
705
|
dataset=dataset,
|
679
706
|
inference_method="transform",
|
@@ -689,8 +716,8 @@ class HistGradientBoostingRegressor(BaseTransformer):
|
|
689
716
|
|
690
717
|
return output_df
|
691
718
|
|
692
|
-
@available_if(
|
693
|
-
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
|
719
|
+
@available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
|
720
|
+
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
694
721
|
""" Method not supported for this class.
|
695
722
|
|
696
723
|
|
@@ -703,13 +730,21 @@ class HistGradientBoostingRegressor(BaseTransformer):
|
|
703
730
|
Returns:
|
704
731
|
Predicted dataset.
|
705
732
|
"""
|
706
|
-
|
707
|
-
|
708
|
-
|
709
|
-
|
710
|
-
|
711
|
-
|
712
|
-
|
733
|
+
self.fit(dataset)
|
734
|
+
assert self._sklearn_object is not None
|
735
|
+
return self._sklearn_object.labels_
|
736
|
+
|
737
|
+
|
738
|
+
@available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
|
739
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
740
|
+
"""
|
741
|
+
Returns:
|
742
|
+
Transformed dataset.
|
743
|
+
"""
|
744
|
+
self.fit(dataset)
|
745
|
+
assert self._sklearn_object is not None
|
746
|
+
return self._sklearn_object.embedding_
|
747
|
+
|
713
748
|
|
714
749
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
715
750
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
|
|
54
54
|
_SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklearn.", "").split("_")])
|
55
55
|
|
56
56
|
|
57
|
+
def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
|
58
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
59
|
+
return False and callable(getattr(self._sklearn_object, "fit_predict", None))
|
60
|
+
return check
|
61
|
+
|
62
|
+
|
63
|
+
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
+
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
+
return check
|
67
|
+
|
68
|
+
|
57
69
|
class IsolationForest(BaseTransformer):
|
58
70
|
r"""Isolation Forest Algorithm
|
59
71
|
For more details on this class, see [sklearn.ensemble.IsolationForest]
|
@@ -187,7 +199,9 @@ class IsolationForest(BaseTransformer):
|
|
187
199
|
self.set_label_cols(label_cols)
|
188
200
|
self.set_passthrough_cols(passthrough_cols)
|
189
201
|
self.set_drop_input_cols(drop_input_cols)
|
190
|
-
self.set_sample_weight_col(sample_weight_col)
|
202
|
+
self.set_sample_weight_col(sample_weight_col)
|
203
|
+
self._use_external_memory_version = False
|
204
|
+
self._batch_size = -1
|
191
205
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
192
206
|
|
193
207
|
self._deps = list(deps)
|
@@ -271,11 +285,6 @@ class IsolationForest(BaseTransformer):
|
|
271
285
|
if isinstance(dataset, DataFrame):
|
272
286
|
session = dataset._session
|
273
287
|
assert session is not None # keep mypy happy
|
274
|
-
# Validate that key package version in user workspace are supported in snowflake conda channel
|
275
|
-
# If customer doesn't have package in conda channel, replace the ones have the closest versions
|
276
|
-
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
277
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
278
|
-
|
279
288
|
# Specify input columns so column pruning will be enforced
|
280
289
|
selected_cols = self._get_active_columns()
|
281
290
|
if len(selected_cols) > 0:
|
@@ -303,7 +312,9 @@ class IsolationForest(BaseTransformer):
|
|
303
312
|
label_cols=self.label_cols,
|
304
313
|
sample_weight_col=self.sample_weight_col,
|
305
314
|
autogenerated=self._autogenerated,
|
306
|
-
subproject=_SUBPROJECT
|
315
|
+
subproject=_SUBPROJECT,
|
316
|
+
use_external_memory_version=self._use_external_memory_version,
|
317
|
+
batch_size=self._batch_size,
|
307
318
|
)
|
308
319
|
self._sklearn_object = model_trainer.train()
|
309
320
|
self._is_fitted = True
|
@@ -574,6 +585,22 @@ class IsolationForest(BaseTransformer):
|
|
574
585
|
# each row containing a list of values.
|
575
586
|
expected_dtype = "ARRAY"
|
576
587
|
|
588
|
+
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
589
|
+
if expected_dtype == "":
|
590
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
|
591
|
+
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
592
|
+
expected_dtype = "ARRAY"
|
593
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
|
594
|
+
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
595
|
+
expected_dtype = "ARRAY"
|
596
|
+
else:
|
597
|
+
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
598
|
+
# We can only infer the output types from the input types if the following two statements are true:
|
599
|
+
# 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
|
600
|
+
# 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
|
601
|
+
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
602
|
+
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
603
|
+
|
577
604
|
output_df = self._batch_inference(
|
578
605
|
dataset=dataset,
|
579
606
|
inference_method="transform",
|
@@ -589,8 +616,8 @@ class IsolationForest(BaseTransformer):
|
|
589
616
|
|
590
617
|
return output_df
|
591
618
|
|
592
|
-
@available_if(
|
593
|
-
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
|
619
|
+
@available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
|
620
|
+
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
594
621
|
""" Perform fit on X and returns labels for X
|
595
622
|
For more details on this function, see [sklearn.ensemble.IsolationForest.fit_predict]
|
596
623
|
(https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.IsolationForest.html#sklearn.ensemble.IsolationForest.fit_predict)
|
@@ -605,13 +632,21 @@ class IsolationForest(BaseTransformer):
|
|
605
632
|
Returns:
|
606
633
|
Predicted dataset.
|
607
634
|
"""
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
635
|
+
self.fit(dataset)
|
636
|
+
assert self._sklearn_object is not None
|
637
|
+
return self._sklearn_object.labels_
|
638
|
+
|
639
|
+
|
640
|
+
@available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
|
641
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
642
|
+
"""
|
643
|
+
Returns:
|
644
|
+
Transformed dataset.
|
645
|
+
"""
|
646
|
+
self.fit(dataset)
|
647
|
+
assert self._sklearn_object is not None
|
648
|
+
return self._sklearn_object.embedding_
|
649
|
+
|
615
650
|
|
616
651
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
617
652
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|