snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
- snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
- snowflake/ml/_internal/env_utils.py +31 -52
- snowflake/ml/_internal/file_utils.py +17 -0
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/_internal/utils/query_result_checker.py +8 -5
- snowflake/ml/_internal/utils/snowflake_env.py +95 -0
- snowflake/ml/fileset/parquet_parser.py +31 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/model_impl.py +172 -13
- snowflake/ml/model/_client/model/model_version_impl.py +96 -52
- snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
- snowflake/ml/model/_client/ops/model_ops.py +155 -9
- snowflake/ml/model/_client/sql/model.py +55 -10
- snowflake/ml/model/_client/sql/model_version.py +72 -61
- snowflake/ml/model/_client/sql/stage.py +10 -4
- snowflake/ml/model/_client/sql/tag.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/_signatures/core.py +20 -17
- snowflake/ml/model/custom_model.py +30 -27
- snowflake/ml/model/model_signature.py +16 -17
- snowflake/ml/model/type_hints.py +3 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
- snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
- snowflake/ml/modeling/_internal/model_specifications.py +3 -10
- snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
- snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
- snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
- snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
- snowflake/ml/modeling/cluster/birch.py +51 -16
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
- snowflake/ml/modeling/cluster/dbscan.py +51 -16
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
- snowflake/ml/modeling/cluster/k_means.py +51 -16
- snowflake/ml/modeling/cluster/mean_shift.py +51 -16
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
- snowflake/ml/modeling/cluster/optics.py +51 -16
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
- snowflake/ml/modeling/compose/column_transformer.py +51 -16
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
- snowflake/ml/modeling/covariance/oas.py +51 -16
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/pca.py +51 -16
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
- snowflake/ml/modeling/impute/knn_imputer.py +51 -16
- snowflake/ml/modeling/impute/missing_indicator.py +51 -16
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/lars.py +51 -16
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/perceptron.py +51 -16
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ridge.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
- snowflake/ml/modeling/manifold/isomap.py +51 -16
- snowflake/ml/modeling/manifold/mds.py +51 -16
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
- snowflake/ml/modeling/manifold/tsne.py +51 -16
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
- snowflake/ml/modeling/svm/linear_svc.py +51 -16
- snowflake/ml/modeling/svm/linear_svr.py +51 -16
- snowflake/ml/modeling/svm/nu_svc.py +51 -16
- snowflake/ml/modeling/svm/nu_svr.py +51 -16
- snowflake/ml/modeling/svm/svc.py +51 -16
- snowflake/ml/modeling/svm/svr.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
- snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
- snowflake/ml/registry/__init__.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +163 -0
- snowflake/ml/registry/model_registry.py +12 -0
- snowflake/ml/registry/registry.py +100 -90
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
- snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
- snowflake/ml/model/_client/model/model_method_info.py +0 -19
- snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
- /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
- /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
snowflake/ml/modeling/decomposition/sparse_pca.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class SparsePCA(BaseTransformer):
     r"""Sparse Principal Components Analysis (SparsePCA)
     For more details on this class, see [sklearn.decomposition.SparsePCA]

@@ -181,7 +193,9 @@ class SparsePCA(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)

@@ -267,11 +281,6 @@ class SparsePCA(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:

@@ -299,7 +308,9 @@ class SparsePCA(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True

@@ -570,6 +581,22 @@ class SparsePCA(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",

@@ -585,8 +612,8 @@ class SparsePCA(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 

@@ -599,13 +626,21 @@ class SparsePCA(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
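
Every estimator touched by the +51/-16 pattern above now gates fit_predict and fit_transform behind scikit-learn's available_if descriptor, with a predicate that is hard-disabled for now (note the `False and callable(...)`). A minimal, self-contained sketch of how available_if hides a method when its check fails — `Gated` and `enabled` are illustrative names, not snowflake-ml API:

# Sketch of sklearn's available_if gating as used in the hunks above.
from sklearn.utils.metaestimators import available_if

class Gated:
    def __init__(self, enabled: bool) -> None:
        self._enabled = enabled

    @available_if(lambda self: self._enabled)
    def fit_predict(self, X):
        return X

print(hasattr(Gated(False), "fit_predict"))  # False: the check fails, so access raises AttributeError
print(hasattr(Gated(True), "fit_predict"))   # True: the method is exposed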
snowflake/ml/modeling/decomposition/truncated_svd.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class TruncatedSVD(BaseTransformer):
     r"""Dimensionality reduction using truncated SVD (aka LSA)
     For more details on this class, see [sklearn.decomposition.TruncatedSVD]

@@ -166,7 +178,9 @@ class TruncatedSVD(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)

@@ -248,11 +262,6 @@ class TruncatedSVD(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:

@@ -280,7 +289,9 @@ class TruncatedSVD(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True

@@ -551,6 +562,22 @@ class TruncatedSVD(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",

@@ -566,8 +593,8 @@ class TruncatedSVD(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 

@@ -580,13 +607,21 @@ class TruncatedSVD(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
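
The transform() hunk repeated in each file adds a fallback for the case where the factory could not assign an output type. Restated as a standalone rule — plain Python, with `estimator`, `output_cols`, and `input_types` as hypothetical stand-ins for the wrapped sklearn object, the output columns, and the inferred Snowpark types:

# Hedged restatement of the dtype-fallback rule from the hunks above.
from typing import Any, List

def infer_expected_dtype(estimator: Any, output_cols: List[str], input_types: List[str]) -> str:
    # Clusterer whose output-column count disagrees with n_clusters -> ARRAY.
    if hasattr(estimator, "n_clusters") and estimator.n_clusters != len(output_cols):
        return "ARRAY"
    # Decomposition transformer: the same rule against n_components.
    if hasattr(estimator, "n_components") and estimator.n_components != len(output_cols):
        return "ARRAY"
    # Uniform input types of matching width keep their type.
    if input_types and all(t == input_types[0] for t in input_types) and len(input_types) == len(output_cols):
        return input_types[0]
    return ""  # left unresolved, as in the diff, for downstream handling

print(infer_expected_dtype(object(), ["OUT1", "OUT2"], ["FLOAT", "FLOAT"]))  # FLOAT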
snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.discriminant_analysis".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class LinearDiscriminantAnalysis(BaseTransformer):
     r"""Linear Discriminant Analysis
     For more details on this class, see [sklearn.discriminant_analysis.LinearDiscriminantAnalysis]

@@ -183,7 +195,9 @@ class LinearDiscriminantAnalysis(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)

@@ -265,11 +279,6 @@ class LinearDiscriminantAnalysis(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:

@@ -297,7 +306,9 @@ class LinearDiscriminantAnalysis(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True

@@ -570,6 +581,22 @@ class LinearDiscriminantAnalysis(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",

@@ -585,8 +612,8 @@ class LinearDiscriminantAnalysis(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 

@@ -599,13 +626,21 @@ class LinearDiscriminantAnalysis(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
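
Every __init__ now sets _use_external_memory_version = False and _batch_size = -1 and forwards both to the trainer builder; per the file list above, only the XGBoost estimators (+69/-16) and the new modeling/_internal/xgboost_external_memory_trainer.py actually exercise them. For orientation, external-memory training in XGBoost itself is driven by a DataIter that supplies batches; a generic, version-sensitive sketch of that mechanism (xgboost >= 1.5 iterator API; this is not the snowflake-ml trainer):

# Generic XGBoost external-memory sketch; names and batch shapes are illustrative.
import numpy as np
import xgboost as xgb

class BatchIter(xgb.DataIter):
    """Feeds fixed-size batches so the full dataset never sits in memory."""

    def __init__(self, batches):
        self._batches = batches
        self._i = 0
        super().__init__(cache_prefix="./xgb_cache")

    def next(self, input_data) -> int:
        if self._i == len(self._batches):
            return 0  # end of one pass over the data
        X, y = self._batches[self._i]
        input_data(data=X, label=y)  # hand the current batch to XGBoost
        self._i += 1
        return 1

    def reset(self) -> None:
        self._i = 0

rng = np.random.RandomState(0)
batches = [(rng.rand(100, 4), rng.randint(0, 2, 100)) for _ in range(3)]
dtrain = xgb.DMatrix(BatchIter(batches))  # the matrix is built batch by batch
booster = xgb.train({"objective": "binary:logistic"}, dtrain, num_boost_round=5)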
snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.discriminant_analysis".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class QuadraticDiscriminantAnalysis(BaseTransformer):
     r"""Quadratic Discriminant Analysis
     For more details on this class, see [sklearn.discriminant_analysis.QuadraticDiscriminantAnalysis]

@@ -148,7 +160,9 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)

@@ -227,11 +241,6 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:

@@ -259,7 +268,9 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True

@@ -530,6 +541,22 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",

@@ -545,8 +572,8 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 

@@ -559,13 +586,21 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
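
The block deleted from every _fit — the pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel call — means the dependency list is no longer rewritten at fit time against the Snowflake conda channel; the versions pinned in __init__ are used as-is. The pinning itself is just f-string version locks on the client environment, shown verbatim in the hunks and runnable on its own:

# Runnable sketch of the dependency pinning from the __init__ hunks above.
import cloudpickle as cp
import numpy as np
import sklearn

deps = {
    f"numpy=={np.__version__}",
    f"scikit-learn=={sklearn.__version__}",
    f"cloudpickle=={cp.__version__}",
}
print(sorted(deps))  # output depends on the local environment, e.g. ['cloudpickle==2.2.1', ...]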
snowflake/ml/modeling/ensemble/ada_boost_classifier.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class AdaBoostClassifier(BaseTransformer):
     r"""An AdaBoost classifier
     For more details on this class, see [sklearn.ensemble.AdaBoostClassifier]

@@ -169,7 +181,9 @@ class AdaBoostClassifier(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
         deps = deps | gather_dependencies(estimator)
         deps = deps | gather_dependencies(base_estimator)

@@ -252,11 +266,6 @@ class AdaBoostClassifier(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:

@@ -284,7 +293,9 @@ class AdaBoostClassifier(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True

@@ -555,6 +566,22 @@ class AdaBoostClassifier(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",

@@ -570,8 +597,8 @@ class AdaBoostClassifier(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 

@@ -584,13 +611,21 @@ class AdaBoostClassifier(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
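
The new fit_predict bodies (fit, then read labels_) follow scikit-learn's own clusterer convention, where fit_predict is defined as fit followed by reading labels_; fit_transform likewise reads embedding_, the attribute manifold learners populate after fitting. A scikit-learn-only check of the labels_ equivalence, with KMeans as an arbitrary clusterer choice:

# Illustration of the fit-then-labels_ pattern the new fit_predict bodies rely on.
import numpy as np
from sklearn.cluster import KMeans

X = np.random.RandomState(0).rand(30, 3)
a = KMeans(n_clusters=2, n_init=10, random_state=0).fit(X).labels_
b = KMeans(n_clusters=2, n_init=10, random_state=0).fit_predict(X)
assert (a == b).all()  # fit + labels_ matches fit_predict for a fixed seed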