snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
- snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
- snowflake/ml/_internal/env_utils.py +31 -52
- snowflake/ml/_internal/file_utils.py +17 -0
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/_internal/utils/query_result_checker.py +8 -5
- snowflake/ml/_internal/utils/snowflake_env.py +95 -0
- snowflake/ml/fileset/parquet_parser.py +31 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/model_impl.py +172 -13
- snowflake/ml/model/_client/model/model_version_impl.py +96 -52
- snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
- snowflake/ml/model/_client/ops/model_ops.py +155 -9
- snowflake/ml/model/_client/sql/model.py +55 -10
- snowflake/ml/model/_client/sql/model_version.py +72 -61
- snowflake/ml/model/_client/sql/stage.py +10 -4
- snowflake/ml/model/_client/sql/tag.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/_signatures/core.py +20 -17
- snowflake/ml/model/custom_model.py +30 -27
- snowflake/ml/model/model_signature.py +16 -17
- snowflake/ml/model/type_hints.py +3 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
- snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
- snowflake/ml/modeling/_internal/model_specifications.py +3 -10
- snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
- snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
- snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
- snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
- snowflake/ml/modeling/cluster/birch.py +51 -16
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
- snowflake/ml/modeling/cluster/dbscan.py +51 -16
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
- snowflake/ml/modeling/cluster/k_means.py +51 -16
- snowflake/ml/modeling/cluster/mean_shift.py +51 -16
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
- snowflake/ml/modeling/cluster/optics.py +51 -16
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
- snowflake/ml/modeling/compose/column_transformer.py +51 -16
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
- snowflake/ml/modeling/covariance/oas.py +51 -16
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/pca.py +51 -16
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
- snowflake/ml/modeling/impute/knn_imputer.py +51 -16
- snowflake/ml/modeling/impute/missing_indicator.py +51 -16
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/lars.py +51 -16
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/perceptron.py +51 -16
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ridge.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
- snowflake/ml/modeling/manifold/isomap.py +51 -16
- snowflake/ml/modeling/manifold/mds.py +51 -16
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
- snowflake/ml/modeling/manifold/tsne.py +51 -16
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
- snowflake/ml/modeling/svm/linear_svc.py +51 -16
- snowflake/ml/modeling/svm/linear_svr.py +51 -16
- snowflake/ml/modeling/svm/nu_svc.py +51 -16
- snowflake/ml/modeling/svm/nu_svr.py +51 -16
- snowflake/ml/modeling/svm/svc.py +51 -16
- snowflake/ml/modeling/svm/svr.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
- snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
- snowflake/ml/registry/__init__.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +163 -0
- snowflake/ml/registry/model_registry.py +12 -0
- snowflake/ml/registry/registry.py +100 -90
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
- snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
- snowflake/ml/model/_client/model/model_method_info.py +0 -19
- snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
- /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
- /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
snowflake/ml/modeling/neural_network/mlp_classifier.py

```diff
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neural_network".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class MLPClassifier(BaseTransformer):
     r"""Multi-layer Perceptron classifier
     For more details on this class, see [sklearn.neural_network.MLPClassifier]
```
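These two factories feed scikit-learn's `available_if` descriptor, used on `fit_predict`/`fit_transform` further down: the decorated method is visible only while the check callable returns a truthy value, and because the generated checks begin with `return False and ...` they can never return True, so the methods stay hidden on classes that do not support them. A minimal, self-contained sketch of the same gating pattern (`Gated`, `NoFitPredict`, and `has_backend` are illustrative names, not part of snowflake-ml-python):

```python
# Minimal sketch of the sklearn `available_if` gating pattern used above.
from sklearn.utils.metaestimators import available_if


def has_backend(attr):
    def check(self):
        # The decorated method is exposed only when this returns a truthy value.
        return callable(getattr(self._backend, attr, None))

    return check


class Gated:
    def __init__(self, backend):
        self._backend = backend

    @available_if(has_backend("fit_predict"))
    def fit_predict(self, X):
        return self._backend.fit_predict(X)


class NoFitPredict:
    pass


print(hasattr(Gated(NoFitPredict()), "fit_predict"))  # False: check fails, method is hidden
```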
```diff
@@ -297,7 +309,9 @@ class MLPClassifier(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
```
```diff
@@ -395,11 +409,6 @@ class MLPClassifier(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
         # Specify input columns so column pruning will be enforced
         selected_cols = self._get_active_columns()
         if len(selected_cols) > 0:
```
```diff
@@ -427,7 +436,9 @@ class MLPClassifier(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
```
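Every autogenerated estimator now threads `use_external_memory_version` and `batch_size` through to the trainer builder; the batching logic itself lands in the new `snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py` (see the file list), and only the XGBoost estimators appear to make use of it — for classes like MLPClassifier the flags stay hard-wired to `False`/`-1`. For orientation, a rough sketch of the external-memory idea using the public `xgboost.DataIter` API; this illustrates the technique only and is not the snowflake-ml trainer's internals:

```python
# Train XGBoost from batches rather than one in-memory matrix.
import numpy as np
import xgboost as xgb


class BatchIter(xgb.DataIter):
    def __init__(self, batches):
        self._batches = batches  # list of (X, y) chunks, e.g. read file by file
        self._i = 0
        super().__init__(cache_prefix="./xgb-cache")  # spill processed pages to disk

    def next(self, input_data):
        if self._i == len(self._batches):
            return 0  # no more batches
        X, y = self._batches[self._i]
        input_data(data=X, label=y)
        self._i += 1
        return 1

    def reset(self):
        self._i = 0


rng = np.random.default_rng(0)
batches = [(rng.normal(size=(256, 4)), rng.integers(0, 2, size=256)) for _ in range(4)]
dmatrix = xgb.DMatrix(BatchIter(batches))  # built incrementally from the iterator
booster = xgb.train({"tree_method": "hist", "objective": "binary:logistic"}, dmatrix, num_boost_round=10)
```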
```diff
@@ -698,6 +709,22 @@ class MLPClassifier(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
```
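Distilled, the added inference rule is: keep the `ARRAY` fallback unless every inferred output type agrees and the transform maps columns one-to-one, in which case the single concrete Snowflake type can be committed to. A standalone illustration (type names here are stand-ins for the real Snowpark type objects):

```python
def infer_expected_dtype(output_types, n_output_cols):
    # Commit to one concrete column type only when every inferred type agrees
    # and the transform is column-for-column; otherwise keep the ARRAY fallback.
    if output_types and all(t == output_types[0] for t in output_types) and len(output_types) == n_output_cols:
        return output_types[0]
    return "ARRAY"


print(infer_expected_dtype(["FLOAT", "FLOAT"], 2))    # FLOAT
print(infer_expected_dtype(["FLOAT", "VARCHAR"], 2))  # ARRAY
print(infer_expected_dtype(["FLOAT"], 3))             # ARRAY
```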
```diff
@@ -713,8 +740,8 @@ class MLPClassifier(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
```
```diff
@@ -727,13 +754,21 @@ class MLPClassifier(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
snowflake/ml/modeling/neural_network/mlp_regressor.py

```diff
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neural_network".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class MLPRegressor(BaseTransformer):
     r"""Multi-layer Perceptron regressor
     For more details on this class, see [sklearn.neural_network.MLPRegressor]
@@ -293,7 +305,9 @@ class MLPRegressor(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -391,11 +405,6 @@ class MLPRegressor(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
         # Specify input columns so column pruning will be enforced
         selected_cols = self._get_active_columns()
         if len(selected_cols) > 0:
@@ -423,7 +432,9 @@ class MLPRegressor(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -694,6 +705,22 @@ class MLPRegressor(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -709,8 +736,8 @@ class MLPRegressor(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -723,13 +750,21 @@ class MLPRegressor(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
snowflake/ml/modeling/preprocessing/min_max_scaler.py

```diff
@@ -8,8 +8,9 @@ from sklearn.preprocessing import _data as sklearn_preprocessing_data
 
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
+from snowflake.ml._internal.exceptions import error_codes, exceptions
 from snowflake.ml.modeling.framework import _utils, base
-from snowflake.snowpark import functions as F
+from snowflake.snowpark import functions as F, types as T
 
 
 class MinMaxScaler(base.BaseTransformer):
```
```diff
@@ -125,6 +126,18 @@ class MinMaxScaler(base.BaseTransformer):
         self.data_max_ = {}
         self.data_range_ = {}
 
+    def _check_input_column_types(self, dataset: snowpark.DataFrame) -> None:
+        for field in dataset.schema.fields:
+            if field.name in self.input_cols:
+                if not issubclass(type(field.datatype), T._NumericType):
+                    raise exceptions.SnowflakeMLException(
+                        error_code=error_codes.INVALID_DATA_TYPE,
+                        original_exception=TypeError(
+                            f"Non-numeric input column {field.name} datatype {field.datatype} "
+                            "is not supported by the MinMaxScaler."
+                        ),
+                    )
+
     @telemetry.send_api_usage_telemetry(
         project=base.PROJECT,
         subproject=base.SUBPROJECT,
```
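The new `_check_input_column_types` makes MinMaxScaler fail fast at fit time with a typed error instead of surfacing a confusing failure later in query execution. A hedged usage sketch; it presupposes an already-created Snowpark `session`, and the column names are invented for illustration:

```python
# What the new check surfaces at fit() time, assuming an active Snowpark `session`.
from snowflake.ml._internal.exceptions import exceptions
from snowflake.ml.modeling.preprocessing import MinMaxScaler

df = session.create_dataframe([(1.0, "a"), (2.0, "b")], schema=["AMOUNT", "LABEL"])

scaler = MinMaxScaler(input_cols=["LABEL"], output_cols=["LABEL_SCALED"])
try:
    scaler.fit(df)  # LABEL is a VARCHAR column, so the type check should trip
except exceptions.SnowflakeMLException as e:
    print(e)  # Non-numeric input column LABEL ... is not supported by the MinMaxScaler.
```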
```diff
@@ -169,6 +182,7 @@ class MinMaxScaler(base.BaseTransformer):
             self.data_range_[input_col] = float(sklearn_scaler.data_range_[i])
 
     def _fit_snowpark(self, dataset: snowpark.DataFrame) -> None:
+        self._check_input_column_types(dataset)
         computed_states = self._compute(dataset, self.input_cols, self.custom_states)
 
         # assign states to the object
```
snowflake/ml/modeling/preprocessing/polynomial_features.py

```diff
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.preprocessing".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class PolynomialFeatures(BaseTransformer):
     r"""Generate polynomial and interaction features
     For more details on this class, see [sklearn.preprocessing.PolynomialFeatures]
@@ -151,7 +163,9 @@ class PolynomialFeatures(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -230,11 +244,6 @@ class PolynomialFeatures(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
         # Specify input columns so column pruning will be enforced
         selected_cols = self._get_active_columns()
         if len(selected_cols) > 0:
@@ -262,7 +271,9 @@ class PolynomialFeatures(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -533,6 +544,22 @@ class PolynomialFeatures(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -548,8 +575,8 @@ class PolynomialFeatures(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -562,13 +589,21 @@ class PolynomialFeatures(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
snowflake/ml/modeling/semi_supervised/label_propagation.py

```diff
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.semi_supervised".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class LabelPropagation(BaseTransformer):
     r"""Label Propagation classifier
     For more details on this class, see [sklearn.semi_supervised.LabelPropagation]
@@ -155,7 +167,9 @@ class LabelPropagation(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -236,11 +250,6 @@ class LabelPropagation(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
         # Specify input columns so column pruning will be enforced
         selected_cols = self._get_active_columns()
         if len(selected_cols) > 0:
@@ -268,7 +277,9 @@ class LabelPropagation(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -539,6 +550,22 @@ class LabelPropagation(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -554,8 +581,8 @@ class LabelPropagation(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -568,13 +595,21 @@ class LabelPropagation(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
snowflake/ml/modeling/semi_supervised/label_spreading.py

```diff
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.semi_supervised".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class LabelSpreading(BaseTransformer):
     r"""LabelSpreading model for semi-supervised learning
     For more details on this class, see [sklearn.semi_supervised.LabelSpreading]
@@ -163,7 +175,9 @@ class LabelSpreading(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -245,11 +259,6 @@ class LabelSpreading(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
         # Specify input columns so column pruning will be enforced
         selected_cols = self._get_active_columns()
         if len(selected_cols) > 0:
@@ -277,7 +286,9 @@ class LabelSpreading(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -548,6 +559,22 @@ class LabelSpreading(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -563,8 +590,8 @@ class LabelSpreading(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -577,13 +604,21 @@ class LabelSpreading(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```