snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in their public registry.
- snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
- snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
- snowflake/ml/_internal/env_utils.py +31 -52
- snowflake/ml/_internal/file_utils.py +17 -0
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/_internal/utils/query_result_checker.py +8 -5
- snowflake/ml/_internal/utils/snowflake_env.py +95 -0
- snowflake/ml/fileset/parquet_parser.py +31 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/model_impl.py +172 -13
- snowflake/ml/model/_client/model/model_version_impl.py +96 -52
- snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
- snowflake/ml/model/_client/ops/model_ops.py +155 -9
- snowflake/ml/model/_client/sql/model.py +55 -10
- snowflake/ml/model/_client/sql/model_version.py +72 -61
- snowflake/ml/model/_client/sql/stage.py +10 -4
- snowflake/ml/model/_client/sql/tag.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/_signatures/core.py +20 -17
- snowflake/ml/model/custom_model.py +30 -27
- snowflake/ml/model/model_signature.py +16 -17
- snowflake/ml/model/type_hints.py +3 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
- snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
- snowflake/ml/modeling/_internal/model_specifications.py +3 -10
- snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
- snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
- snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
- snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
- snowflake/ml/modeling/cluster/birch.py +51 -16
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
- snowflake/ml/modeling/cluster/dbscan.py +51 -16
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
- snowflake/ml/modeling/cluster/k_means.py +51 -16
- snowflake/ml/modeling/cluster/mean_shift.py +51 -16
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
- snowflake/ml/modeling/cluster/optics.py +51 -16
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
- snowflake/ml/modeling/compose/column_transformer.py +51 -16
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
- snowflake/ml/modeling/covariance/oas.py +51 -16
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/pca.py +51 -16
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
- snowflake/ml/modeling/impute/knn_imputer.py +51 -16
- snowflake/ml/modeling/impute/missing_indicator.py +51 -16
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/lars.py +51 -16
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/perceptron.py +51 -16
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ridge.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
- snowflake/ml/modeling/manifold/isomap.py +51 -16
- snowflake/ml/modeling/manifold/mds.py +51 -16
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
- snowflake/ml/modeling/manifold/tsne.py +51 -16
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
- snowflake/ml/modeling/svm/linear_svc.py +51 -16
- snowflake/ml/modeling/svm/linear_svr.py +51 -16
- snowflake/ml/modeling/svm/nu_svc.py +51 -16
- snowflake/ml/modeling/svm/nu_svr.py +51 -16
- snowflake/ml/modeling/svm/svc.py +51 -16
- snowflake/ml/modeling/svm/svr.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
- snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
- snowflake/ml/registry/__init__.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +163 -0
- snowflake/ml/registry/model_registry.py +12 -0
- snowflake/ml/registry/registry.py +100 -90
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
- snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
- snowflake/ml/model/_client/model/model_method_info.py +0 -19
- snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
- /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
- /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
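
Most of the line count comes from the autogenerated `snowflake/ml/modeling/*` wrappers, which all receive the same +51/−16 change. The functional centerpieces of 1.2.1 are the reworked model registry (`registry/registry.py` and the new `registry/_manager/model_manager.py`), tag support (`model/_client/sql/tag.py`), and external-memory XGBoost training (`modeling/_internal/xgboost_external_memory_trainer.py`). Below is a minimal sketch of the registry workflow those files back, assuming an existing Snowpark `session` and a fitted `model`; the call shapes follow the documented 1.2 API but are not themselves shown in this diff:

```python
from snowflake.ml.registry import Registry

registry = Registry(session=session)      # `session`: an existing Snowpark session
mv = registry.log_model(                  # returns a ModelVersion handle
    model,                                # e.g. a fitted sklearn or xgboost model
    model_name="MY_MODEL",
    version_name="V1",
)
predictions = mv.run(test_df, function_name="predict")  # invoke a model method on a DataFrame
```

Representative hunks from the autogenerated wrappers follow; the same pattern repeats, at shifted line offsets, across every `+51 -16` file listed above.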
snowflake/ml/modeling/covariance/oas.py

```diff
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.covariance".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class OAS(BaseTransformer):
     r"""Oracle Approximating Shrinkage Estimator as proposed in [1]_
     For more details on this class, see [sklearn.covariance.OAS]
@@ -133,7 +145,9 @@ class OAS(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -210,11 +224,6 @@ class OAS(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -242,7 +251,9 @@ class OAS(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -511,6 +522,22 @@ class OAS(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -526,8 +553,8 @@ class OAS(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -540,13 +567,21 @@ class OAS(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
snowflake/ml/modeling/covariance/shrunk_covariance.py

```diff
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.covariance".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class ShrunkCovariance(BaseTransformer):
     r"""Covariance estimator with shrinkage
     For more details on this class, see [sklearn.covariance.ShrunkCovariance]
@@ -138,7 +150,9 @@ class ShrunkCovariance(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -216,11 +230,6 @@ class ShrunkCovariance(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -248,7 +257,9 @@ class ShrunkCovariance(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -517,6 +528,22 @@ class ShrunkCovariance(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -532,8 +559,8 @@ class ShrunkCovariance(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -546,13 +573,21 @@ class ShrunkCovariance(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
snowflake/ml/modeling/decomposition/dictionary_learning.py

```diff
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class DictionaryLearning(BaseTransformer):
     r"""Dictionary learning
     For more details on this class, see [sklearn.decomposition.DictionaryLearning]
@@ -229,7 +241,9 @@ class DictionaryLearning(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -322,11 +336,6 @@ class DictionaryLearning(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -354,7 +363,9 @@ class DictionaryLearning(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -625,6 +636,22 @@ class DictionaryLearning(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -640,8 +667,8 @@ class DictionaryLearning(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -654,13 +681,21 @@ class DictionaryLearning(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
snowflake/ml/modeling/decomposition/fast_ica.py

```diff
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class FastICA(BaseTransformer):
     r"""FastICA: a fast algorithm for Independent Component Analysis
     For more details on this class, see [sklearn.decomposition.FastICA]
@@ -192,7 +204,9 @@ class FastICA(BaseTransformer):
         self.set_label_cols(label_cols)
        self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -277,11 +291,6 @@ class FastICA(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -309,7 +318,9 @@ class FastICA(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -580,6 +591,22 @@ class FastICA(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -595,8 +622,8 @@ class FastICA(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -609,13 +636,21 @@ class FastICA(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```