snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
- snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
- snowflake/ml/_internal/env_utils.py +31 -52
- snowflake/ml/_internal/file_utils.py +17 -0
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/_internal/utils/query_result_checker.py +8 -5
- snowflake/ml/_internal/utils/snowflake_env.py +95 -0
- snowflake/ml/fileset/parquet_parser.py +31 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/model_impl.py +172 -13
- snowflake/ml/model/_client/model/model_version_impl.py +96 -52
- snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
- snowflake/ml/model/_client/ops/model_ops.py +155 -9
- snowflake/ml/model/_client/sql/model.py +55 -10
- snowflake/ml/model/_client/sql/model_version.py +72 -61
- snowflake/ml/model/_client/sql/stage.py +10 -4
- snowflake/ml/model/_client/sql/tag.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/_signatures/core.py +20 -17
- snowflake/ml/model/custom_model.py +30 -27
- snowflake/ml/model/model_signature.py +16 -17
- snowflake/ml/model/type_hints.py +3 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
- snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
- snowflake/ml/modeling/_internal/model_specifications.py +3 -10
- snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
- snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
- snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
- snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
- snowflake/ml/modeling/cluster/birch.py +51 -16
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
- snowflake/ml/modeling/cluster/dbscan.py +51 -16
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
- snowflake/ml/modeling/cluster/k_means.py +51 -16
- snowflake/ml/modeling/cluster/mean_shift.py +51 -16
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
- snowflake/ml/modeling/cluster/optics.py +51 -16
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
- snowflake/ml/modeling/compose/column_transformer.py +51 -16
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
- snowflake/ml/modeling/covariance/oas.py +51 -16
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/pca.py +51 -16
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
- snowflake/ml/modeling/impute/knn_imputer.py +51 -16
- snowflake/ml/modeling/impute/missing_indicator.py +51 -16
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/lars.py +51 -16
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/perceptron.py +51 -16
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ridge.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
- snowflake/ml/modeling/manifold/isomap.py +51 -16
- snowflake/ml/modeling/manifold/mds.py +51 -16
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
- snowflake/ml/modeling/manifold/tsne.py +51 -16
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
- snowflake/ml/modeling/svm/linear_svc.py +51 -16
- snowflake/ml/modeling/svm/linear_svr.py +51 -16
- snowflake/ml/modeling/svm/nu_svc.py +51 -16
- snowflake/ml/modeling/svm/nu_svr.py +51 -16
- snowflake/ml/modeling/svm/svc.py +51 -16
- snowflake/ml/modeling/svm/svr.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
- snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
- snowflake/ml/registry/__init__.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +163 -0
- snowflake/ml/registry/model_registry.py +12 -0
- snowflake/ml/registry/registry.py +100 -90
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
- snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
- snowflake/ml/model/_client/model/model_method_info.py +0 -19
- snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
- /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
- /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
snowflake/ml/modeling/covariance/empirical_covariance.py (+51 -16):

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.covariance".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class EmpiricalCovariance(BaseTransformer):
     r"""Maximum likelihood covariance estimator
     For more details on this class, see [sklearn.covariance.EmpiricalCovariance]
@@ -133,7 +145,9 @@ class EmpiricalCovariance(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -210,11 +224,6 @@ class EmpiricalCovariance(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -242,7 +251,9 @@ class EmpiricalCovariance(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -511,6 +522,22 @@ class EmpiricalCovariance(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -526,8 +553,8 @@ class EmpiricalCovariance(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -540,13 +567,21 @@ class EmpiricalCovariance(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
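The helper pair added above exists to feed scikit-learn's `available_if` descriptor: the decorated method is only visible on an instance when the check callable returns truthy, so `hasattr` probing sees an honest surface. For these covariance estimators the checks are hard-coded to `False and ...`, which keeps `fit_predict` and `fit_transform` hidden while letting the code generator emit one template for every estimator. A minimal, self-contained sketch of the pattern; `WrappedEstimator`, `_Backend`, and `_backend` are illustrative names, not part of the package:

from sklearn.utils.metaestimators import available_if


class WrappedEstimator:
    def __init__(self, backend=None):
        self._backend = backend  # the underlying sklearn object, if any

    def _has_fit_predict(self) -> bool:
        # Same shape as the generated check: gate on the wrapped object's capability.
        return callable(getattr(self._backend, "fit_predict", None))

    @available_if(_has_fit_predict)
    def fit_predict(self, X):
        return self._backend.fit_predict(X)


class _Backend:
    def fit_predict(self, X):
        return [0] * len(X)


print(hasattr(WrappedEstimator(), "fit_predict"))          # False: available_if raises AttributeError
print(hasattr(WrappedEstimator(_Backend()), "fit_predict"))  # True: the check passes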
snowflake/ml/modeling/covariance/graphical_lasso.py (+51 -16):

Same change set as empirical_covariance.py above (the gating helpers, the `_use_external_memory_version`/`_batch_size` fields, removal of the conda-channel dependency validation, the extra trainer kwargs, the `expected_dtype` inference block, and the `available_if`-gated `fit_predict`/`fit_transform`), applied to `class GraphicalLasso(BaseTransformer)` (r"""Sparse inverse covariance estimation with an l1-penalized estimator"""). Hunks:

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
@@ -174,7 +186,9 @@ class GraphicalLasso(BaseTransformer):
@@ -258,11 +272,6 @@ class GraphicalLasso(BaseTransformer):
@@ -290,7 +299,9 @@ class GraphicalLasso(BaseTransformer):
@@ -559,6 +570,22 @@ class GraphicalLasso(BaseTransformer):
@@ -574,8 +601,8 @@ class GraphicalLasso(BaseTransformer):
@@ -588,13 +615,21 @@ class GraphicalLasso(BaseTransformer):
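The new `expected_dtype` fallback in `transform` (the `@@ -511,6 +522,22 @@` hunk in empirical_covariance.py above) decides whether an output column can keep a scalar Snowflake type or must degrade to ARRAY. A standalone paraphrase of that decision logic, assuming types are passed as plain strings; `infer_expected_dtype` is an illustrative name, not the package's API:

from types import SimpleNamespace


def infer_expected_dtype(sklearn_obj, output_cols, input_col_types):
    # Clustering transformers: unless there is one output column per cluster,
    # each row value is a list, i.e. a Snowflake ARRAY.
    if hasattr(sklearn_obj, "n_clusters") and sklearn_obj.n_clusters != len(output_cols):
        return "ARRAY"
    # Decomposition transformers: same rule, keyed on n_components.
    if hasattr(sklearn_obj, "n_components") and sklearn_obj.n_components != len(output_cols):
        return "ARRAY"
    # Otherwise the output type is inferable from the inputs only when all
    # inputs share one type and the input/output column counts match.
    if len(set(input_col_types)) == 1 and len(input_col_types) == len(output_cols):
        return input_col_types[0]
    return ""  # left unresolved, as in the generated code


print(infer_expected_dtype(SimpleNamespace(n_clusters=8), ["OUT"], ["FLOAT"]))  # ARRAY
print(infer_expected_dtype(SimpleNamespace(), ["A", "B"], ["FLOAT", "FLOAT"]))  # FLOAT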
snowflake/ml/modeling/covariance/graphical_lasso_cv.py (+51 -16):

Same change set as empirical_covariance.py, applied to `class GraphicalLassoCV(BaseTransformer)` (r"""Sparse inverse covariance w/ cross-validated choice of the l1 penalty"""). Hunks:

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
@@ -198,7 +210,9 @@ class GraphicalLassoCV(BaseTransformer):
@@ -284,11 +298,6 @@ class GraphicalLassoCV(BaseTransformer):
@@ -316,7 +325,9 @@ class GraphicalLassoCV(BaseTransformer):
@@ -585,6 +596,22 @@ class GraphicalLassoCV(BaseTransformer):
@@ -600,8 +627,8 @@ class GraphicalLassoCV(BaseTransformer):
@@ -614,13 +641,21 @@ class GraphicalLassoCV(BaseTransformer):
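Every autogenerated estimator now initializes `_use_external_memory_version = False` and `_batch_size = -1` and forwards both into the trainer construction. Judging from the new `modeling/_internal/xgboost_external_memory_trainer.py` (+444) and the larger XGBoost estimator diffs (+69 each), the XGBoost classes are presumably the only ones that enable these flags. A minimal sketch of the dispatch the two kwargs make possible; `StandardTrainer`, `ExternalMemoryTrainer`, and `build_trainer` are hypothetical stand-ins, not the package's classes:

class StandardTrainer:
    """Hypothetical default path: fit on the fully materialized dataset."""

    def __init__(self, estimator):
        self.estimator = estimator


class ExternalMemoryTrainer:
    """Hypothetical external-memory path: fit on fixed-size row batches."""

    def __init__(self, estimator, batch_size):
        if batch_size <= 0:
            raise ValueError("external-memory training needs a positive batch size")
        self.estimator = estimator
        self.batch_size = batch_size


def build_trainer(estimator, use_external_memory_version=False, batch_size=-1):
    # Mirrors the kwargs threaded through every estimator's fit path;
    # batch_size == -1 matches the new "not batched" default.
    if use_external_memory_version:
        return ExternalMemoryTrainer(estimator, batch_size=batch_size)
    return StandardTrainer(estimator)


trainer = build_trainer(object(), use_external_memory_version=True, batch_size=10_000)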
snowflake/ml/modeling/covariance/ledoit_wolf.py (+51 -16):

Same change set as empirical_covariance.py, applied to `class LedoitWolf(BaseTransformer)` (r"""LedoitWolf Estimator"""). Hunks:

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
@@ -139,7 +151,9 @@ class LedoitWolf(BaseTransformer):
@@ -217,11 +231,6 @@ class LedoitWolf(BaseTransformer):
@@ -249,7 +258,9 @@ class LedoitWolf(BaseTransformer):
@@ -518,6 +529,22 @@ class LedoitWolf(BaseTransformer):
@@ -533,8 +560,8 @@ class LedoitWolf(BaseTransformer):
@@ -547,13 +574,21 @@ class LedoitWolf(BaseTransformer):
snowflake/ml/modeling/covariance/min_cov_det.py (+51 -16):

Same change set as empirical_covariance.py, applied to `class MinCovDet(BaseTransformer)` (r"""Minimum Covariance Determinant (MCD): robust estimator of covariance"""). Hunks:

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
@@ -150,7 +162,9 @@ class MinCovDet(BaseTransformer):
@@ -229,11 +243,6 @@ class MinCovDet(BaseTransformer):
@@ -261,7 +270,9 @@ class MinCovDet(BaseTransformer):
@@ -530,6 +541,22 @@ class MinCovDet(BaseTransformer):
@@ -545,8 +572,8 @@ class MinCovDet(BaseTransformer):
@@ -559,13 +586,21 @@ class MinCovDet(BaseTransformer):