snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
- snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
- snowflake/ml/_internal/env_utils.py +31 -52
- snowflake/ml/_internal/file_utils.py +17 -0
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/_internal/utils/query_result_checker.py +8 -5
- snowflake/ml/_internal/utils/snowflake_env.py +95 -0
- snowflake/ml/fileset/parquet_parser.py +31 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/model_impl.py +172 -13
- snowflake/ml/model/_client/model/model_version_impl.py +96 -52
- snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
- snowflake/ml/model/_client/ops/model_ops.py +155 -9
- snowflake/ml/model/_client/sql/model.py +55 -10
- snowflake/ml/model/_client/sql/model_version.py +72 -61
- snowflake/ml/model/_client/sql/stage.py +10 -4
- snowflake/ml/model/_client/sql/tag.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/_signatures/core.py +20 -17
- snowflake/ml/model/custom_model.py +30 -27
- snowflake/ml/model/model_signature.py +16 -17
- snowflake/ml/model/type_hints.py +3 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
- snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
- snowflake/ml/modeling/_internal/model_specifications.py +3 -10
- snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
- snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
- snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
- snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
- snowflake/ml/modeling/cluster/birch.py +51 -16
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
- snowflake/ml/modeling/cluster/dbscan.py +51 -16
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
- snowflake/ml/modeling/cluster/k_means.py +51 -16
- snowflake/ml/modeling/cluster/mean_shift.py +51 -16
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
- snowflake/ml/modeling/cluster/optics.py +51 -16
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
- snowflake/ml/modeling/compose/column_transformer.py +51 -16
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
- snowflake/ml/modeling/covariance/oas.py +51 -16
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/pca.py +51 -16
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
- snowflake/ml/modeling/impute/knn_imputer.py +51 -16
- snowflake/ml/modeling/impute/missing_indicator.py +51 -16
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/lars.py +51 -16
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/perceptron.py +51 -16
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ridge.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
- snowflake/ml/modeling/manifold/isomap.py +51 -16
- snowflake/ml/modeling/manifold/mds.py +51 -16
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
- snowflake/ml/modeling/manifold/tsne.py +51 -16
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
- snowflake/ml/modeling/svm/linear_svc.py +51 -16
- snowflake/ml/modeling/svm/linear_svr.py +51 -16
- snowflake/ml/modeling/svm/nu_svc.py +51 -16
- snowflake/ml/modeling/svm/nu_svr.py +51 -16
- snowflake/ml/modeling/svm/svc.py +51 -16
- snowflake/ml/modeling/svm/svr.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
- snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
- snowflake/ml/registry/__init__.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +163 -0
- snowflake/ml/registry/model_registry.py +12 -0
- snowflake/ml/registry/registry.py +100 -90
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
- snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
- snowflake/ml/model/_client/model/model_method_info.py +0 -19
- snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
- /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
- /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
|
|
54
54
|
_SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neighbors".replace("sklearn.", "").split("_")])
|
55
55
|
|
56
56
|
|
57
|
+
def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
|
58
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
59
|
+
return False and callable(getattr(self._sklearn_object, "fit_predict", None))
|
60
|
+
return check
|
61
|
+
|
62
|
+
|
63
|
+
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
+
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
+
return check
|
67
|
+
|
68
|
+
|
57
69
|
class NearestNeighbors(BaseTransformer):
|
58
70
|
r"""Unsupervised learner for implementing neighbor searches
|
59
71
|
For more details on this class, see [sklearn.neighbors.NearestNeighbors]
|
@@ -188,7 +200,9 @@ class NearestNeighbors(BaseTransformer):
|
|
188
200
|
self.set_label_cols(label_cols)
|
189
201
|
self.set_passthrough_cols(passthrough_cols)
|
190
202
|
self.set_drop_input_cols(drop_input_cols)
|
191
|
-
self.set_sample_weight_col(sample_weight_col)
|
203
|
+
self.set_sample_weight_col(sample_weight_col)
|
204
|
+
self._use_external_memory_version = False
|
205
|
+
self._batch_size = -1
|
192
206
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
193
207
|
|
194
208
|
self._deps = list(deps)
|
@@ -271,11 +285,6 @@ class NearestNeighbors(BaseTransformer):
|
|
271
285
|
if isinstance(dataset, DataFrame):
|
272
286
|
session = dataset._session
|
273
287
|
assert session is not None # keep mypy happy
|
274
|
-
# Validate that key package version in user workspace are supported in snowflake conda channel
|
275
|
-
# If customer doesn't have package in conda channel, replace the ones have the closest versions
|
276
|
-
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
277
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
278
|
-
|
279
288
|
# Specify input columns so column pruning will be enforced
|
280
289
|
selected_cols = self._get_active_columns()
|
281
290
|
if len(selected_cols) > 0:
|
@@ -303,7 +312,9 @@ class NearestNeighbors(BaseTransformer):
|
|
303
312
|
label_cols=self.label_cols,
|
304
313
|
sample_weight_col=self.sample_weight_col,
|
305
314
|
autogenerated=self._autogenerated,
|
306
|
-
subproject=_SUBPROJECT
|
315
|
+
subproject=_SUBPROJECT,
|
316
|
+
use_external_memory_version=self._use_external_memory_version,
|
317
|
+
batch_size=self._batch_size,
|
307
318
|
)
|
308
319
|
self._sklearn_object = model_trainer.train()
|
309
320
|
self._is_fitted = True
|
@@ -572,6 +583,22 @@ class NearestNeighbors(BaseTransformer):
|
|
572
583
|
# each row containing a list of values.
|
573
584
|
expected_dtype = "ARRAY"
|
574
585
|
|
586
|
+
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
587
|
+
if expected_dtype == "":
|
588
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
|
589
|
+
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
590
|
+
expected_dtype = "ARRAY"
|
591
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
|
592
|
+
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
593
|
+
expected_dtype = "ARRAY"
|
594
|
+
else:
|
595
|
+
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
596
|
+
# We can only infer the output types from the input types if the following two statemetns are true:
|
597
|
+
# 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
|
598
|
+
# 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
|
599
|
+
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
600
|
+
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
601
|
+
|
575
602
|
output_df = self._batch_inference(
|
576
603
|
dataset=dataset,
|
577
604
|
inference_method="transform",
|
@@ -587,8 +614,8 @@ class NearestNeighbors(BaseTransformer):
|
|
587
614
|
|
588
615
|
return output_df
|
589
616
|
|
590
|
-
@available_if(
|
591
|
-
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
|
617
|
+
@available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
|
618
|
+
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
592
619
|
""" Method not supported for this class.
|
593
620
|
|
594
621
|
|
@@ -601,13 +628,21 @@ class NearestNeighbors(BaseTransformer):
|
|
601
628
|
Returns:
|
602
629
|
Predicted dataset.
|
603
630
|
"""
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
|
608
|
-
|
609
|
-
|
610
|
-
|
631
|
+
self.fit(dataset)
|
632
|
+
assert self._sklearn_object is not None
|
633
|
+
return self._sklearn_object.labels_
|
634
|
+
|
635
|
+
|
636
|
+
@available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
|
637
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
638
|
+
"""
|
639
|
+
Returns:
|
640
|
+
Transformed dataset.
|
641
|
+
"""
|
642
|
+
self.fit(dataset)
|
643
|
+
assert self._sklearn_object is not None
|
644
|
+
return self._sklearn_object.embedding_
|
645
|
+
|
611
646
|
|
612
647
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
613
648
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
|
|
54
54
|
_SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neighbors".replace("sklearn.", "").split("_")])
|
55
55
|
|
56
56
|
|
57
|
+
def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
|
58
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
59
|
+
return False and callable(getattr(self._sklearn_object, "fit_predict", None))
|
60
|
+
return check
|
61
|
+
|
62
|
+
|
63
|
+
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
+
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
+
return check
|
67
|
+
|
68
|
+
|
57
69
|
class NeighborhoodComponentsAnalysis(BaseTransformer):
|
58
70
|
r"""Neighborhood Components Analysis
|
59
71
|
For more details on this class, see [sklearn.neighbors.NeighborhoodComponentsAnalysis]
|
@@ -209,7 +221,9 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
209
221
|
self.set_label_cols(label_cols)
|
210
222
|
self.set_passthrough_cols(passthrough_cols)
|
211
223
|
self.set_drop_input_cols(drop_input_cols)
|
212
|
-
self.set_sample_weight_col(sample_weight_col)
|
224
|
+
self.set_sample_weight_col(sample_weight_col)
|
225
|
+
self._use_external_memory_version = False
|
226
|
+
self._batch_size = -1
|
213
227
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
214
228
|
|
215
229
|
self._deps = list(deps)
|
@@ -292,11 +306,6 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
292
306
|
if isinstance(dataset, DataFrame):
|
293
307
|
session = dataset._session
|
294
308
|
assert session is not None # keep mypy happy
|
295
|
-
# Validate that key package version in user workspace are supported in snowflake conda channel
|
296
|
-
# If customer doesn't have package in conda channel, replace the ones have the closest versions
|
297
|
-
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
298
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
299
|
-
|
300
309
|
# Specify input columns so column pruning will be enforced
|
301
310
|
selected_cols = self._get_active_columns()
|
302
311
|
if len(selected_cols) > 0:
|
@@ -324,7 +333,9 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
324
333
|
label_cols=self.label_cols,
|
325
334
|
sample_weight_col=self.sample_weight_col,
|
326
335
|
autogenerated=self._autogenerated,
|
327
|
-
subproject=_SUBPROJECT
|
336
|
+
subproject=_SUBPROJECT,
|
337
|
+
use_external_memory_version=self._use_external_memory_version,
|
338
|
+
batch_size=self._batch_size,
|
328
339
|
)
|
329
340
|
self._sklearn_object = model_trainer.train()
|
330
341
|
self._is_fitted = True
|
@@ -595,6 +606,22 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
595
606
|
# each row containing a list of values.
|
596
607
|
expected_dtype = "ARRAY"
|
597
608
|
|
609
|
+
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
610
|
+
if expected_dtype == "":
|
611
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
|
612
|
+
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
613
|
+
expected_dtype = "ARRAY"
|
614
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
|
615
|
+
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
616
|
+
expected_dtype = "ARRAY"
|
617
|
+
else:
|
618
|
+
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
619
|
+
# We can only infer the output types from the input types if the following two statemetns are true:
|
620
|
+
# 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
|
621
|
+
# 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
|
622
|
+
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
623
|
+
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
624
|
+
|
598
625
|
output_df = self._batch_inference(
|
599
626
|
dataset=dataset,
|
600
627
|
inference_method="transform",
|
@@ -610,8 +637,8 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
610
637
|
|
611
638
|
return output_df
|
612
639
|
|
613
|
-
@available_if(
|
614
|
-
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
|
640
|
+
@available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
|
641
|
+
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
615
642
|
""" Method not supported for this class.
|
616
643
|
|
617
644
|
|
@@ -624,13 +651,21 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
|
|
624
651
|
Returns:
|
625
652
|
Predicted dataset.
|
626
653
|
"""
|
627
|
-
|
628
|
-
|
629
|
-
|
630
|
-
|
631
|
-
|
632
|
-
|
633
|
-
|
654
|
+
self.fit(dataset)
|
655
|
+
assert self._sklearn_object is not None
|
656
|
+
return self._sklearn_object.labels_
|
657
|
+
|
658
|
+
|
659
|
+
@available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
|
660
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
661
|
+
"""
|
662
|
+
Returns:
|
663
|
+
Transformed dataset.
|
664
|
+
"""
|
665
|
+
self.fit(dataset)
|
666
|
+
assert self._sklearn_object is not None
|
667
|
+
return self._sklearn_object.embedding_
|
668
|
+
|
634
669
|
|
635
670
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
636
671
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
|
|
54
54
|
_SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neighbors".replace("sklearn.", "").split("_")])
|
55
55
|
|
56
56
|
|
57
|
+
def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
|
58
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
59
|
+
return False and callable(getattr(self._sklearn_object, "fit_predict", None))
|
60
|
+
return check
|
61
|
+
|
62
|
+
|
63
|
+
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
+
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
+
return check
|
67
|
+
|
68
|
+
|
57
69
|
class RadiusNeighborsClassifier(BaseTransformer):
|
58
70
|
r"""Classifier implementing a vote among neighbors within a given radius
|
59
71
|
For more details on this class, see [sklearn.neighbors.RadiusNeighborsClassifier]
|
@@ -209,7 +221,9 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
209
221
|
self.set_label_cols(label_cols)
|
210
222
|
self.set_passthrough_cols(passthrough_cols)
|
211
223
|
self.set_drop_input_cols(drop_input_cols)
|
212
|
-
self.set_sample_weight_col(sample_weight_col)
|
224
|
+
self.set_sample_weight_col(sample_weight_col)
|
225
|
+
self._use_external_memory_version = False
|
226
|
+
self._batch_size = -1
|
213
227
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
214
228
|
|
215
229
|
self._deps = list(deps)
|
@@ -293,11 +307,6 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
293
307
|
if isinstance(dataset, DataFrame):
|
294
308
|
session = dataset._session
|
295
309
|
assert session is not None # keep mypy happy
|
296
|
-
# Validate that key package version in user workspace are supported in snowflake conda channel
|
297
|
-
# If customer doesn't have package in conda channel, replace the ones have the closest versions
|
298
|
-
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
299
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
300
|
-
|
301
310
|
# Specify input columns so column pruning will be enforced
|
302
311
|
selected_cols = self._get_active_columns()
|
303
312
|
if len(selected_cols) > 0:
|
@@ -325,7 +334,9 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
325
334
|
label_cols=self.label_cols,
|
326
335
|
sample_weight_col=self.sample_weight_col,
|
327
336
|
autogenerated=self._autogenerated,
|
328
|
-
subproject=_SUBPROJECT
|
337
|
+
subproject=_SUBPROJECT,
|
338
|
+
use_external_memory_version=self._use_external_memory_version,
|
339
|
+
batch_size=self._batch_size,
|
329
340
|
)
|
330
341
|
self._sklearn_object = model_trainer.train()
|
331
342
|
self._is_fitted = True
|
@@ -596,6 +607,22 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
596
607
|
# each row containing a list of values.
|
597
608
|
expected_dtype = "ARRAY"
|
598
609
|
|
610
|
+
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
611
|
+
if expected_dtype == "":
|
612
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
|
613
|
+
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
614
|
+
expected_dtype = "ARRAY"
|
615
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
|
616
|
+
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
617
|
+
expected_dtype = "ARRAY"
|
618
|
+
else:
|
619
|
+
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
620
|
+
# We can only infer the output types from the input types if the following two statemetns are true:
|
621
|
+
# 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
|
622
|
+
# 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
|
623
|
+
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
624
|
+
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
625
|
+
|
599
626
|
output_df = self._batch_inference(
|
600
627
|
dataset=dataset,
|
601
628
|
inference_method="transform",
|
@@ -611,8 +638,8 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
611
638
|
|
612
639
|
return output_df
|
613
640
|
|
614
|
-
@available_if(
|
615
|
-
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
|
641
|
+
@available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
|
642
|
+
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
616
643
|
""" Method not supported for this class.
|
617
644
|
|
618
645
|
|
@@ -625,13 +652,21 @@ class RadiusNeighborsClassifier(BaseTransformer):
|
|
625
652
|
Returns:
|
626
653
|
Predicted dataset.
|
627
654
|
"""
|
628
|
-
|
629
|
-
|
630
|
-
|
631
|
-
|
632
|
-
|
633
|
-
|
634
|
-
|
655
|
+
self.fit(dataset)
|
656
|
+
assert self._sklearn_object is not None
|
657
|
+
return self._sklearn_object.labels_
|
658
|
+
|
659
|
+
|
660
|
+
@available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
|
661
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
662
|
+
"""
|
663
|
+
Returns:
|
664
|
+
Transformed dataset.
|
665
|
+
"""
|
666
|
+
self.fit(dataset)
|
667
|
+
assert self._sklearn_object is not None
|
668
|
+
return self._sklearn_object.embedding_
|
669
|
+
|
635
670
|
|
636
671
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
637
672
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
|
|
54
54
|
_SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neighbors".replace("sklearn.", "").split("_")])
|
55
55
|
|
56
56
|
|
57
|
+
def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
|
58
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
59
|
+
return False and callable(getattr(self._sklearn_object, "fit_predict", None))
|
60
|
+
return check
|
61
|
+
|
62
|
+
|
63
|
+
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
+
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
+
return check
|
67
|
+
|
68
|
+
|
57
69
|
class RadiusNeighborsRegressor(BaseTransformer):
|
58
70
|
r"""Regression based on neighbors within a fixed radius
|
59
71
|
For more details on this class, see [sklearn.neighbors.RadiusNeighborsRegressor]
|
@@ -200,7 +212,9 @@ class RadiusNeighborsRegressor(BaseTransformer):
|
|
200
212
|
self.set_label_cols(label_cols)
|
201
213
|
self.set_passthrough_cols(passthrough_cols)
|
202
214
|
self.set_drop_input_cols(drop_input_cols)
|
203
|
-
self.set_sample_weight_col(sample_weight_col)
|
215
|
+
self.set_sample_weight_col(sample_weight_col)
|
216
|
+
self._use_external_memory_version = False
|
217
|
+
self._batch_size = -1
|
204
218
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
205
219
|
|
206
220
|
self._deps = list(deps)
|
@@ -283,11 +297,6 @@ class RadiusNeighborsRegressor(BaseTransformer):
|
|
283
297
|
if isinstance(dataset, DataFrame):
|
284
298
|
session = dataset._session
|
285
299
|
assert session is not None # keep mypy happy
|
286
|
-
# Validate that key package version in user workspace are supported in snowflake conda channel
|
287
|
-
# If customer doesn't have package in conda channel, replace the ones have the closest versions
|
288
|
-
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
289
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
290
|
-
|
291
300
|
# Specify input columns so column pruning will be enforced
|
292
301
|
selected_cols = self._get_active_columns()
|
293
302
|
if len(selected_cols) > 0:
|
@@ -315,7 +324,9 @@ class RadiusNeighborsRegressor(BaseTransformer):
|
|
315
324
|
label_cols=self.label_cols,
|
316
325
|
sample_weight_col=self.sample_weight_col,
|
317
326
|
autogenerated=self._autogenerated,
|
318
|
-
subproject=_SUBPROJECT
|
327
|
+
subproject=_SUBPROJECT,
|
328
|
+
use_external_memory_version=self._use_external_memory_version,
|
329
|
+
batch_size=self._batch_size,
|
319
330
|
)
|
320
331
|
self._sklearn_object = model_trainer.train()
|
321
332
|
self._is_fitted = True
|
@@ -586,6 +597,22 @@ class RadiusNeighborsRegressor(BaseTransformer):
|
|
586
597
|
# each row containing a list of values.
|
587
598
|
expected_dtype = "ARRAY"
|
588
599
|
|
600
|
+
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
601
|
+
if expected_dtype == "":
|
602
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
|
603
|
+
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
604
|
+
expected_dtype = "ARRAY"
|
605
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
|
606
|
+
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
607
|
+
expected_dtype = "ARRAY"
|
608
|
+
else:
|
609
|
+
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
610
|
+
# We can only infer the output types from the input types if the following two statements are true:
|
611
|
+
# 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
|
612
|
+
# 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
|
613
|
+
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
614
|
+
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
615
|
+
|
589
616
|
output_df = self._batch_inference(
|
590
617
|
dataset=dataset,
|
591
618
|
inference_method="transform",
|
@@ -601,8 +628,8 @@ class RadiusNeighborsRegressor(BaseTransformer):
|
|
601
628
|
|
602
629
|
return output_df
|
603
630
|
|
604
|
-
@available_if(
|
605
|
-
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
|
631
|
+
@available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
|
632
|
+
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
606
633
|
""" Method not supported for this class.
|
607
634
|
|
608
635
|
|
@@ -615,13 +642,21 @@ class RadiusNeighborsRegressor(BaseTransformer):
|
|
615
642
|
Returns:
|
616
643
|
Predicted dataset.
|
617
644
|
"""
|
618
|
-
|
619
|
-
|
620
|
-
|
621
|
-
|
622
|
-
|
623
|
-
|
624
|
-
|
645
|
+
self.fit(dataset)
|
646
|
+
assert self._sklearn_object is not None
|
647
|
+
return self._sklearn_object.labels_
|
648
|
+
|
649
|
+
|
650
|
+
@available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
|
651
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
652
|
+
"""
|
653
|
+
Returns:
|
654
|
+
Transformed dataset.
|
655
|
+
"""
|
656
|
+
self.fit(dataset)
|
657
|
+
assert self._sklearn_object is not None
|
658
|
+
return self._sklearn_object.embedding_
|
659
|
+
|
625
660
|
|
626
661
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
627
662
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
|
|
54
54
|
_SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neural_network".replace("sklearn.", "").split("_")])
|
55
55
|
|
56
56
|
|
57
|
+
def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
|
58
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
59
|
+
return False and callable(getattr(self._sklearn_object, "fit_predict", None))
|
60
|
+
return check
|
61
|
+
|
62
|
+
|
63
|
+
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
+
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
+
return check
|
67
|
+
|
68
|
+
|
57
69
|
class BernoulliRBM(BaseTransformer):
|
58
70
|
r"""Bernoulli Restricted Boltzmann Machine (RBM)
|
59
71
|
For more details on this class, see [sklearn.neural_network.BernoulliRBM]
|
@@ -159,7 +171,9 @@ class BernoulliRBM(BaseTransformer):
|
|
159
171
|
self.set_label_cols(label_cols)
|
160
172
|
self.set_passthrough_cols(passthrough_cols)
|
161
173
|
self.set_drop_input_cols(drop_input_cols)
|
162
|
-
self.set_sample_weight_col(sample_weight_col)
|
174
|
+
self.set_sample_weight_col(sample_weight_col)
|
175
|
+
self._use_external_memory_version = False
|
176
|
+
self._batch_size = -1
|
163
177
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
164
178
|
|
165
179
|
self._deps = list(deps)
|
@@ -240,11 +254,6 @@ class BernoulliRBM(BaseTransformer):
|
|
240
254
|
if isinstance(dataset, DataFrame):
|
241
255
|
session = dataset._session
|
242
256
|
assert session is not None # keep mypy happy
|
243
|
-
# Validate that key package version in user workspace are supported in snowflake conda channel
|
244
|
-
# If customer doesn't have package in conda channel, replace the ones have the closest versions
|
245
|
-
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
246
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
247
|
-
|
248
257
|
# Specify input columns so column pruning will be enforced
|
249
258
|
selected_cols = self._get_active_columns()
|
250
259
|
if len(selected_cols) > 0:
|
@@ -272,7 +281,9 @@ class BernoulliRBM(BaseTransformer):
|
|
272
281
|
label_cols=self.label_cols,
|
273
282
|
sample_weight_col=self.sample_weight_col,
|
274
283
|
autogenerated=self._autogenerated,
|
275
|
-
subproject=_SUBPROJECT
|
284
|
+
subproject=_SUBPROJECT,
|
285
|
+
use_external_memory_version=self._use_external_memory_version,
|
286
|
+
batch_size=self._batch_size,
|
276
287
|
)
|
277
288
|
self._sklearn_object = model_trainer.train()
|
278
289
|
self._is_fitted = True
|
@@ -543,6 +554,22 @@ class BernoulliRBM(BaseTransformer):
|
|
543
554
|
# each row containing a list of values.
|
544
555
|
expected_dtype = "ARRAY"
|
545
556
|
|
557
|
+
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
558
|
+
if expected_dtype == "":
|
559
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
|
560
|
+
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
561
|
+
expected_dtype = "ARRAY"
|
562
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
|
563
|
+
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
564
|
+
expected_dtype = "ARRAY"
|
565
|
+
else:
|
566
|
+
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
567
|
+
# We can only infer the output types from the input types if the following two statements are true:
|
568
|
+
# 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
|
569
|
+
# 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
|
570
|
+
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
571
|
+
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
572
|
+
|
546
573
|
output_df = self._batch_inference(
|
547
574
|
dataset=dataset,
|
548
575
|
inference_method="transform",
|
@@ -558,8 +585,8 @@ class BernoulliRBM(BaseTransformer):
|
|
558
585
|
|
559
586
|
return output_df
|
560
587
|
|
561
|
-
@available_if(
|
562
|
-
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
|
588
|
+
@available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
|
589
|
+
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
563
590
|
""" Method not supported for this class.
|
564
591
|
|
565
592
|
|
@@ -572,13 +599,21 @@ class BernoulliRBM(BaseTransformer):
|
|
572
599
|
Returns:
|
573
600
|
Predicted dataset.
|
574
601
|
"""
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
|
581
|
-
|
602
|
+
self.fit(dataset)
|
603
|
+
assert self._sklearn_object is not None
|
604
|
+
return self._sklearn_object.labels_
|
605
|
+
|
606
|
+
|
607
|
+
@available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
|
608
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
609
|
+
"""
|
610
|
+
Returns:
|
611
|
+
Transformed dataset.
|
612
|
+
"""
|
613
|
+
self.fit(dataset)
|
614
|
+
assert self._sklearn_object is not None
|
615
|
+
return self._sklearn_object.embedding_
|
616
|
+
|
582
617
|
|
583
618
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
584
619
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|