snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registries. It is provided for informational purposes only.
- snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
- snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
- snowflake/ml/_internal/env_utils.py +31 -52
- snowflake/ml/_internal/file_utils.py +17 -0
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/_internal/utils/query_result_checker.py +8 -5
- snowflake/ml/_internal/utils/snowflake_env.py +95 -0
- snowflake/ml/fileset/parquet_parser.py +31 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/model_impl.py +172 -13
- snowflake/ml/model/_client/model/model_version_impl.py +96 -52
- snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
- snowflake/ml/model/_client/ops/model_ops.py +155 -9
- snowflake/ml/model/_client/sql/model.py +55 -10
- snowflake/ml/model/_client/sql/model_version.py +72 -61
- snowflake/ml/model/_client/sql/stage.py +10 -4
- snowflake/ml/model/_client/sql/tag.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/_signatures/core.py +20 -17
- snowflake/ml/model/custom_model.py +30 -27
- snowflake/ml/model/model_signature.py +16 -17
- snowflake/ml/model/type_hints.py +3 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
- snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
- snowflake/ml/modeling/_internal/model_specifications.py +3 -10
- snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
- snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
- snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
- snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
- snowflake/ml/modeling/cluster/birch.py +51 -16
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
- snowflake/ml/modeling/cluster/dbscan.py +51 -16
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
- snowflake/ml/modeling/cluster/k_means.py +51 -16
- snowflake/ml/modeling/cluster/mean_shift.py +51 -16
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
- snowflake/ml/modeling/cluster/optics.py +51 -16
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
- snowflake/ml/modeling/compose/column_transformer.py +51 -16
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
- snowflake/ml/modeling/covariance/oas.py +51 -16
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/pca.py +51 -16
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
- snowflake/ml/modeling/impute/knn_imputer.py +51 -16
- snowflake/ml/modeling/impute/missing_indicator.py +51 -16
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/lars.py +51 -16
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/perceptron.py +51 -16
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ridge.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
- snowflake/ml/modeling/manifold/isomap.py +51 -16
- snowflake/ml/modeling/manifold/mds.py +51 -16
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
- snowflake/ml/modeling/manifold/tsne.py +51 -16
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
- snowflake/ml/modeling/svm/linear_svc.py +51 -16
- snowflake/ml/modeling/svm/linear_svr.py +51 -16
- snowflake/ml/modeling/svm/nu_svc.py +51 -16
- snowflake/ml/modeling/svm/nu_svr.py +51 -16
- snowflake/ml/modeling/svm/svc.py +51 -16
- snowflake/ml/modeling/svm/svr.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
- snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
- snowflake/ml/registry/__init__.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +163 -0
- snowflake/ml/registry/model_registry.py +12 -0
- snowflake/ml/registry/registry.py +100 -90
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
- snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
- snowflake/ml/model/_client/model/model_method_info.py +0 -19
- snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
- /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
- /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
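
Among the more API-facing entries above is the reworked model registry (`snowflake/ml/registry/registry.py`, the new `snowflake/ml/registry/_manager/model_manager.py`, and tag support in `snowflake/ml/model/_client/sql/tag.py`). As rough orientation, here is a minimal sketch of that workflow, assuming the public 1.2-era `Registry` API; connection parameters are placeholders, and the keyword names should be checked against the release itself:

```python
# Minimal sketch (not taken from the diff) of the 1.2-era registry workflow.
# Connection parameters are placeholders; keyword names are assumptions.
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression

from snowflake.ml.registry import Registry
from snowflake.snowpark import Session

session = Session.builder.configs(
    {"account": "<account>", "user": "<user>", "password": "<password>"}
).create()

train = pd.DataFrame(np.random.rand(100, 3), columns=["F1", "F2", "F3"])
model = LinearRegression().fit(train, np.random.rand(100))

reg = Registry(session=session)
mv = reg.log_model(  # returns a ModelVersion handle
    model,
    model_name="MY_MODEL",
    version_name="V1",
    sample_input_data=train.head(10),
)
print(mv.show_functions())                              # registered methods, e.g. "predict"
print(mv.run(train.head(5), function_name="predict"))   # server-side inference
```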
```diff
--- a/snowflake/ml/modeling/cluster/spectral_clustering.py
+++ b/snowflake/ml/modeling/cluster/spectral_clustering.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return True and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class SpectralClustering(BaseTransformer):
     r"""Apply clustering to a projection of the normalized Laplacian
     For more details on this class, see [sklearn.cluster.SpectralClustering]
@@ -237,7 +249,9 @@ class SpectralClustering(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -327,11 +341,6 @@ class SpectralClustering(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
         # Specify input columns so column pruning will be enforced
         selected_cols = self._get_active_columns()
         if len(selected_cols) > 0:
@@ -359,7 +368,9 @@ class SpectralClustering(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -628,6 +639,22 @@ class SpectralClustering(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -643,8 +670,8 @@ class SpectralClustering(BaseTransformer):

         return output_df

-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Perform spectral clustering on `X` and return cluster labels
         For more details on this function, see [sklearn.cluster.SpectralClustering.fit_predict]
         (https://scikit-learn.org/stable/modules/generated/sklearn.cluster.SpectralClustering.html#sklearn.cluster.SpectralClustering.fit_predict)
@@ -659,13 +686,21 @@ class SpectralClustering(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
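
In the hunk above, the generated `_is_fit_predict_method_enabled` factory feeds scikit-learn's `available_if` decorator, which hides a method entirely (`hasattr` returns False) when its check fails. A standalone sketch of that gating pattern, using a hypothetical `Wrapper` class rather than the package's `BaseTransformer`:

```python
# Standalone sketch of available_if gating; Wrapper is a hypothetical
# stand-in for the package's BaseTransformer, not its actual code.
from typing import Any, Callable

from sklearn.cluster import KMeans
from sklearn.utils.metaestimators import available_if


def _fit_predict_enabled() -> Callable[[Any], bool]:
    def check(self: "Wrapper") -> bool:
        # True only when the wrapped estimator exposes fit_predict.
        return callable(getattr(self._sklearn_object, "fit_predict", None))

    return check


class Wrapper:
    def __init__(self, sklearn_object: Any) -> None:
        self._sklearn_object = sklearn_object

    @available_if(_fit_predict_enabled())
    def fit_predict(self, X: Any) -> Any:
        return self._sklearn_object.fit_predict(X)


print(hasattr(Wrapper(KMeans(n_clusters=2)), "fit_predict"))  # True
print(hasattr(Wrapper(object()), "fit_predict"))              # False
```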
```diff
--- a/snowflake/ml/modeling/cluster/spectral_coclustering.py
+++ b/snowflake/ml/modeling/cluster/spectral_coclustering.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class SpectralCoclustering(BaseTransformer):
     r"""Spectral Co-Clustering algorithm (Dhillon, 2001)
     For more details on this class, see [sklearn.cluster.SpectralCoclustering]
@@ -166,7 +178,9 @@ class SpectralCoclustering(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -248,11 +262,6 @@ class SpectralCoclustering(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
         # Specify input columns so column pruning will be enforced
         selected_cols = self._get_active_columns()
         if len(selected_cols) > 0:
@@ -280,7 +289,9 @@ class SpectralCoclustering(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -549,6 +560,22 @@ class SpectralCoclustering(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -564,8 +591,8 @@ class SpectralCoclustering(BaseTransformer):

         return output_df

-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -578,13 +605,21 @@ class SpectralCoclustering(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
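
The 16-line block added to each `transform` path above is a fallback for transforms whose output type could not be fixed up front. Restated as standalone Python, with the internal Snowpark helpers (`_infer_signature`, `convert_sp_to_sf_type`) replaced by a plain list of input type names (an illustrative simplification, not the package's code):

```python
# Illustrative restatement of the dtype-inference fallback added above.
# input_sf_types stands in for the Snowflake types of the input columns.
from typing import Any, List

from sklearn.cluster import KMeans
from sklearn.decomposition import PCA


def infer_expected_dtype(est: Any, output_cols: List[str], input_sf_types: List[str]) -> str:
    expected_dtype = ""
    # Clustering transformer whose output column count differs from n_clusters -> ARRAY.
    if hasattr(est, "n_clusters") and getattr(est, "n_clusters") != len(output_cols):
        expected_dtype = "ARRAY"
    # Decomposition transformer whose output column count differs from n_components -> ARRAY.
    elif hasattr(est, "n_components") and getattr(est, "n_components") != len(output_cols):
        expected_dtype = "ARRAY"
    else:
        # Only safe to reuse the input type when all inputs share one type
        # and the column counts line up; otherwise stay unset (variant/ARRAY).
        if all(t == input_sf_types[0] for t in input_sf_types) and len(input_sf_types) == len(output_cols):
            expected_dtype = input_sf_types[0]
    return expected_dtype


print(infer_expected_dtype(KMeans(n_clusters=3), ["CLUSTER"], ["DOUBLE"]))              # ARRAY
print(infer_expected_dtype(PCA(n_components=2), ["PC1", "PC2"], ["DOUBLE", "DOUBLE"]))  # DOUBLE
```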
```diff
--- a/snowflake/ml/modeling/compose/column_transformer.py
+++ b/snowflake/ml/modeling/compose/column_transformer.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.compose".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class ColumnTransformer(BaseTransformer):
     r"""Applies transformers to columns of an array or pandas DataFrame
     For more details on this class, see [sklearn.compose.ColumnTransformer]
@@ -196,7 +208,9 @@ class ColumnTransformer(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
         deps = deps | gather_dependencies(transformers)
         self._deps = list(deps)
@@ -278,11 +292,6 @@ class ColumnTransformer(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
         # Specify input columns so column pruning will be enforced
         selected_cols = self._get_active_columns()
         if len(selected_cols) > 0:
@@ -310,7 +319,9 @@ class ColumnTransformer(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -581,6 +592,22 @@ class ColumnTransformer(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -596,8 +623,8 @@ class ColumnTransformer(BaseTransformer):

         return output_df

-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -610,13 +637,21 @@ class ColumnTransformer(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
```diff
--- a/snowflake/ml/modeling/compose/transformed_target_regressor.py
+++ b/snowflake/ml/modeling/compose/transformed_target_regressor.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.compose".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class TransformedTargetRegressor(BaseTransformer):
     r"""Meta-estimator to regress on a transformed target
     For more details on this class, see [sklearn.compose.TransformedTargetRegressor]
@@ -159,7 +171,9 @@ class TransformedTargetRegressor(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -239,11 +253,6 @@ class TransformedTargetRegressor(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
         # Specify input columns so column pruning will be enforced
         selected_cols = self._get_active_columns()
         if len(selected_cols) > 0:
@@ -271,7 +280,9 @@ class TransformedTargetRegressor(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -542,6 +553,22 @@ class TransformedTargetRegressor(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -557,8 +584,8 @@ class TransformedTargetRegressor(BaseTransformer):

         return output_df

-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -571,13 +598,21 @@ class TransformedTargetRegressor(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
```diff
--- a/snowflake/ml/modeling/covariance/elliptic_envelope.py
+++ b/snowflake/ml/modeling/covariance/elliptic_envelope.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.covariance".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class EllipticEnvelope(BaseTransformer):
     r"""An object for detecting outliers in a Gaussian distributed dataset
     For more details on this class, see [sklearn.covariance.EllipticEnvelope]
@@ -154,7 +166,9 @@ class EllipticEnvelope(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -234,11 +248,6 @@ class EllipticEnvelope(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
         # Specify input columns so column pruning will be enforced
         selected_cols = self._get_active_columns()
         if len(selected_cols) > 0:
@@ -266,7 +275,9 @@ class EllipticEnvelope(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -537,6 +548,22 @@ class EllipticEnvelope(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -552,8 +579,8 @@ class EllipticEnvelope(BaseTransformer):

         return output_df

-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Perform fit on X and returns labels for X
         For more details on this function, see [sklearn.covariance.EllipticEnvelope.fit_predict]
         (https://scikit-learn.org/stable/modules/generated/sklearn.covariance.EllipticEnvelope.html#sklearn.covariance.EllipticEnvelope.fit_predict)
@@ -568,13 +595,21 @@ class EllipticEnvelope(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
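
Each constructor hunk above also threads the new `use_external_memory_version` and `batch_size` attributes through to the trainer builder, backed by the new `snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py`. For context, a generic sketch of XGBoost's external-memory ingestion, assuming the public `xgboost.DataIter` API (xgboost >= 1.7); it illustrates the batching idea only, not the package's trainer:

```python
# Generic sketch of batched (external-memory) XGBoost training; this is an
# illustration of the idea behind use_external_memory_version/batch_size,
# not the package's xgboost_external_memory_trainer implementation.
import numpy as np
import xgboost as xgb


class BatchIter(xgb.DataIter):
    def __init__(self, batches):
        self._batches = batches  # list of (X, y) chunks, e.g. one per Parquet file
        self._it = 0
        super().__init__(cache_prefix="xgb_cache")  # spill DMatrix pages to disk

    def next(self, input_data):
        if self._it == len(self._batches):
            return 0  # signal end of one pass over the data
        X, y = self._batches[self._it]
        input_data(data=X, label=y)  # hand one batch to XGBoost
        self._it += 1
        return 1

    def reset(self):
        self._it = 0


batch_size = 1_000
batches = [
    (np.random.rand(batch_size, 4), np.random.randint(2, size=batch_size))
    for _ in range(5)
]
dtrain = xgb.DMatrix(BatchIter(batches))  # pages are built batch by batch
booster = xgb.train({"objective": "binary:logistic", "tree_method": "hist"}, dtrain)
```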