snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
- snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
- snowflake/ml/_internal/env_utils.py +31 -52
- snowflake/ml/_internal/file_utils.py +17 -0
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/_internal/utils/query_result_checker.py +8 -5
- snowflake/ml/_internal/utils/snowflake_env.py +95 -0
- snowflake/ml/fileset/parquet_parser.py +31 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/model_impl.py +172 -13
- snowflake/ml/model/_client/model/model_version_impl.py +96 -52
- snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
- snowflake/ml/model/_client/ops/model_ops.py +155 -9
- snowflake/ml/model/_client/sql/model.py +55 -10
- snowflake/ml/model/_client/sql/model_version.py +72 -61
- snowflake/ml/model/_client/sql/stage.py +10 -4
- snowflake/ml/model/_client/sql/tag.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/_signatures/core.py +20 -17
- snowflake/ml/model/custom_model.py +30 -27
- snowflake/ml/model/model_signature.py +16 -17
- snowflake/ml/model/type_hints.py +3 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
- snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
- snowflake/ml/modeling/_internal/model_specifications.py +3 -10
- snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
- snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
- snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
- snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
- snowflake/ml/modeling/cluster/birch.py +51 -16
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
- snowflake/ml/modeling/cluster/dbscan.py +51 -16
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
- snowflake/ml/modeling/cluster/k_means.py +51 -16
- snowflake/ml/modeling/cluster/mean_shift.py +51 -16
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
- snowflake/ml/modeling/cluster/optics.py +51 -16
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
- snowflake/ml/modeling/compose/column_transformer.py +51 -16
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
- snowflake/ml/modeling/covariance/oas.py +51 -16
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/pca.py +51 -16
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
- snowflake/ml/modeling/impute/knn_imputer.py +51 -16
- snowflake/ml/modeling/impute/missing_indicator.py +51 -16
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/lars.py +51 -16
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/perceptron.py +51 -16
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ridge.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
- snowflake/ml/modeling/manifold/isomap.py +51 -16
- snowflake/ml/modeling/manifold/mds.py +51 -16
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
- snowflake/ml/modeling/manifold/tsne.py +51 -16
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
- snowflake/ml/modeling/svm/linear_svc.py +51 -16
- snowflake/ml/modeling/svm/linear_svr.py +51 -16
- snowflake/ml/modeling/svm/nu_svc.py +51 -16
- snowflake/ml/modeling/svm/nu_svr.py +51 -16
- snowflake/ml/modeling/svm/svc.py +51 -16
- snowflake/ml/modeling/svm/svr.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
- snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
- snowflake/ml/registry/__init__.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +163 -0
- snowflake/ml/registry/model_registry.py +12 -0
- snowflake/ml/registry/registry.py +100 -90
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
- snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
- snowflake/ml/model/_client/model/model_method_info.py +0 -19
- snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
- /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
- /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
snowflake/ml/modeling/decomposition/incremental_pca.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class IncrementalPCA(BaseTransformer):
     r"""Incremental principal components analysis (IPCA)
     For more details on this class, see [sklearn.decomposition.IncrementalPCA]
@@ -150,7 +162,9 @@ class IncrementalPCA(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -229,11 +243,6 @@ class IncrementalPCA(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -261,7 +270,9 @@ class IncrementalPCA(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -532,6 +543,22 @@ class IncrementalPCA(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -547,8 +574,8 @@ class IncrementalPCA(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -561,13 +588,21 @@ class IncrementalPCA(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
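The `@available_if(...)` gating added above (and repeated in the sibling decomposition classes below) hides `fit_predict`/`fit_transform` unless the predicate returns True; with the hard-coded `False and ...` both methods stay unavailable for now. A minimal, self-contained sketch of that pattern, assuming scikit-learn's `sklearn.utils.metaestimators.available_if` and a hypothetical `DemoTransformer` standing in for the generated wrappers:

```python
from sklearn.utils.metaestimators import available_if


def _is_fit_predict_method_enabled():
    # Same shape as the generated helper: `False and ...` always evaluates to False,
    # so the decorated method is reported as unavailable.
    def check(self):
        return False and callable(getattr(self._sklearn_object, "fit_predict", None))

    return check


class DemoTransformer:
    # Hypothetical stand-in for the generated BaseTransformer subclasses.
    _sklearn_object = None

    @available_if(_is_fit_predict_method_enabled())
    def fit_predict(self, dataset):
        return dataset


t = DemoTransformer()
print(hasattr(t, "fit_predict"))  # False: available_if raises AttributeError when the check fails
```

Dropping the `False and` so the predicate actually inspects the wrapped estimator is what would expose the method once the underlying `fit_predict`/`fit_transform` paths are supported.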
snowflake/ml/modeling/decomposition/kernel_pca.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class KernelPCA(BaseTransformer):
     r"""Kernel Principal component analysis (KPCA) [1]_
     For more details on this class, see [sklearn.decomposition.KernelPCA]
@@ -234,7 +246,9 @@ class KernelPCA(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -325,11 +339,6 @@ class KernelPCA(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -357,7 +366,9 @@ class KernelPCA(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -628,6 +639,22 @@ class KernelPCA(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -643,8 +670,8 @@ class KernelPCA(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -657,13 +684,21 @@ class KernelPCA(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class MiniBatchDictionaryLearning(BaseTransformer):
     r"""Mini-batch dictionary learning
     For more details on this class, see [sklearn.decomposition.MiniBatchDictionaryLearning]
@@ -251,7 +263,9 @@ class MiniBatchDictionaryLearning(BaseTransformer):
         self.set_label_cols(label_cols)
        self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -347,11 +361,6 @@ class MiniBatchDictionaryLearning(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -379,7 +388,9 @@ class MiniBatchDictionaryLearning(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -650,6 +661,22 @@ class MiniBatchDictionaryLearning(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -665,8 +692,8 @@ class MiniBatchDictionaryLearning(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -679,13 +706,21 @@ class MiniBatchDictionaryLearning(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class MiniBatchSparsePCA(BaseTransformer):
     r"""Mini-batch Sparse Principal Components Analysis
     For more details on this class, see [sklearn.decomposition.MiniBatchSparsePCA]
@@ -203,7 +215,9 @@ class MiniBatchSparsePCA(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -292,11 +306,6 @@ class MiniBatchSparsePCA(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -324,7 +333,9 @@ class MiniBatchSparsePCA(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -595,6 +606,22 @@ class MiniBatchSparsePCA(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -610,8 +637,8 @@ class MiniBatchSparsePCA(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -624,13 +651,21 @@ class MiniBatchSparsePCA(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/decomposition/pca.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class PCA(BaseTransformer):
     r"""Principal component analysis (PCA)
     For more details on this class, see [sklearn.decomposition.PCA]
@@ -210,7 +222,9 @@ class PCA(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -294,11 +308,6 @@ class PCA(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -326,7 +335,9 @@ class PCA(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -597,6 +608,22 @@ class PCA(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -612,8 +639,8 @@ class PCA(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -626,13 +653,21 @@ class PCA(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.