snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.
- snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
- snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
- snowflake/ml/_internal/env_utils.py +31 -52
- snowflake/ml/_internal/file_utils.py +17 -0
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/_internal/utils/query_result_checker.py +8 -5
- snowflake/ml/_internal/utils/snowflake_env.py +95 -0
- snowflake/ml/fileset/parquet_parser.py +31 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/model_impl.py +172 -13
- snowflake/ml/model/_client/model/model_version_impl.py +96 -52
- snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
- snowflake/ml/model/_client/ops/model_ops.py +155 -9
- snowflake/ml/model/_client/sql/model.py +55 -10
- snowflake/ml/model/_client/sql/model_version.py +72 -61
- snowflake/ml/model/_client/sql/stage.py +10 -4
- snowflake/ml/model/_client/sql/tag.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/_signatures/core.py +20 -17
- snowflake/ml/model/custom_model.py +30 -27
- snowflake/ml/model/model_signature.py +16 -17
- snowflake/ml/model/type_hints.py +3 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
- snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
- snowflake/ml/modeling/_internal/model_specifications.py +3 -10
- snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
- snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
- snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
- snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
- snowflake/ml/modeling/cluster/birch.py +51 -16
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
- snowflake/ml/modeling/cluster/dbscan.py +51 -16
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
- snowflake/ml/modeling/cluster/k_means.py +51 -16
- snowflake/ml/modeling/cluster/mean_shift.py +51 -16
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
- snowflake/ml/modeling/cluster/optics.py +51 -16
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
- snowflake/ml/modeling/compose/column_transformer.py +51 -16
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
- snowflake/ml/modeling/covariance/oas.py +51 -16
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/pca.py +51 -16
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
- snowflake/ml/modeling/impute/knn_imputer.py +51 -16
- snowflake/ml/modeling/impute/missing_indicator.py +51 -16
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/lars.py +51 -16
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/perceptron.py +51 -16
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ridge.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
- snowflake/ml/modeling/manifold/isomap.py +51 -16
- snowflake/ml/modeling/manifold/mds.py +51 -16
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
- snowflake/ml/modeling/manifold/tsne.py +51 -16
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
- snowflake/ml/modeling/svm/linear_svc.py +51 -16
- snowflake/ml/modeling/svm/linear_svr.py +51 -16
- snowflake/ml/modeling/svm/nu_svc.py +51 -16
- snowflake/ml/modeling/svm/nu_svr.py +51 -16
- snowflake/ml/modeling/svm/svc.py +51 -16
- snowflake/ml/modeling/svm/svr.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
- snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
- snowflake/ml/registry/__init__.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +163 -0
- snowflake/ml/registry/model_registry.py +12 -0
- snowflake/ml/registry/registry.py +100 -90
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
- snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
- snowflake/ml/model/_client/model/model_method_info.py +0 -19
- snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
- /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
- /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
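The most visible change in the file list is the reworked model registry: `snowflake/ml/registry/registry.py` is largely rewritten and backed by the new `snowflake/ml/registry/_manager/model_manager.py`, with tagging support arriving via `snowflake/ml/model/_client/sql/tag.py`. As a hedged sketch of the new `Registry` entry point (names follow the public `snowflake.ml.registry` API of the 1.2.x line; `session`, `model`, and `test_df` are assumed to exist):

```python
# A minimal sketch, assuming a live Snowpark `session`, a fitted in-memory
# `model`, and a `test_df` DataFrame; API names per the snowflake.ml.registry docs.
from snowflake.ml.registry import Registry

reg = Registry(session=session)  # defaults to the session's current database/schema
mv = reg.log_model(
    model,
    model_name="my_model",
    version_name="v1",
)
preds = mv.run(test_df, function_name="predict")  # server-side batch inference
```

The hunks below reproduce the diff for five of the autogenerated modeling wrappers (`IterativeImputer`, `KNNImputer`, `MissingIndicator`, `AdditiveChi2Sampler`, `Nystroem`); the same template changes repeat across the other `+51 -16` modeling files listed above.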
--- a/snowflake/ml/modeling/impute/iterative_imputer.py
+++ b/snowflake/ml/modeling/impute/iterative_imputer.py
@@ -55,6 +55,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.impute".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class IterativeImputer(BaseTransformer):
     r"""Multivariate imputer that estimates each feature from all the others
     For more details on this class, see [sklearn.impute.IterativeImputer]
@@ -241,7 +253,9 @@ class IterativeImputer(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
         deps = deps | gather_dependencies(estimator)
         self._deps = list(deps)
@@ -332,11 +346,6 @@ class IterativeImputer(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -364,7 +373,9 @@ class IterativeImputer(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -635,6 +646,22 @@ class IterativeImputer(BaseTransformer):
         # each row containing a list of values.
         expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -650,8 +677,8 @@ class IterativeImputer(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -664,13 +691,21 @@ class IterativeImputer(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
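Two details of this hunk recur in every wrapper below. First, the `@available_if` decorator is now fed by module-level check factories, and each generated check returns `False and callable(...)`, which always evaluates to `False`, so `fit_predict` and `fit_transform` stay hidden on these imputer classes. Second, the hiding works through scikit-learn's `available_if`, which raises `AttributeError` when the check fails. A minimal sketch of that mechanism, assuming nothing beyond scikit-learn itself (`Wrapped` and `_impl` are invented for the demo):

```python
# Minimal sketch of the gating pattern; Wrapped and _impl are hypothetical names.
from sklearn.cluster import KMeans
from sklearn.utils.metaestimators import available_if


def _fit_predict_enabled(self) -> bool:
    # Mirrors the generated check: the literal False disables the method,
    # so attribute lookup raises AttributeError and hasattr() returns False.
    return False and callable(getattr(self._impl, "fit_predict", None))


class Wrapped:
    def __init__(self, impl):
        self._impl = impl

    @available_if(_fit_predict_enabled)
    def fit_predict(self, X):
        return self._impl.fit_predict(X)


w = Wrapped(KMeans(n_clusters=2))
print(hasattr(w, "fit_predict"))  # False: available_if hides the method
```

The generated code passes a factory call, `available_if(_is_fit_predict_method_enabled())`, which returns the same kind of check function used directly here.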
--- a/snowflake/ml/modeling/impute/knn_imputer.py
+++ b/snowflake/ml/modeling/impute/knn_imputer.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.impute".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class KNNImputer(BaseTransformer):
     r"""Imputation for completing missing values using k-Nearest Neighbors
     For more details on this class, see [sklearn.impute.KNNImputer]
@@ -176,7 +188,9 @@ class KNNImputer(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -258,11 +272,6 @@ class KNNImputer(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -290,7 +299,9 @@ class KNNImputer(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -561,6 +572,22 @@ class KNNImputer(BaseTransformer):
         # each row containing a list of values.
         expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -576,8 +603,8 @@ class KNNImputer(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -590,13 +617,21 @@ class KNNImputer(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
--- a/snowflake/ml/modeling/impute/missing_indicator.py
+++ b/snowflake/ml/modeling/impute/missing_indicator.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.impute".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class MissingIndicator(BaseTransformer):
     r"""Binary indicators for missing values
     For more details on this class, see [sklearn.impute.MissingIndicator]
@@ -153,7 +165,9 @@ class MissingIndicator(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -232,11 +246,6 @@ class MissingIndicator(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -264,7 +273,9 @@ class MissingIndicator(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -535,6 +546,22 @@ class MissingIndicator(BaseTransformer):
         # each row containing a list of values.
         expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -550,8 +577,8 @@ class MissingIndicator(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -564,13 +591,21 @@ class MissingIndicator(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
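The `expected_dtype` fallback added to each `transform()` decides whether the server-side output column can keep a concrete Snowflake type or must degrade to `ARRAY` (and ultimately variant). Distilled into a standalone function for readability (the function and its parameters are hypothetical; the real code operates on `self._sklearn_object` and Snowpark signature objects):

```python
from typing import List, Optional


def infer_expected_dtype(sk_obj: object, output_cols: List[str],
                         input_sf_types: List[str]) -> Optional[str]:
    """Distillation of the added branch; None means keep the variant fallback."""
    # Clustering: if the output width cannot match n_clusters, rows become arrays.
    if hasattr(sk_obj, "n_clusters") and getattr(sk_obj, "n_clusters") != len(output_cols):
        return "ARRAY"
    # Decomposition: the same rule, checked against n_components.
    if hasattr(sk_obj, "n_components") and getattr(sk_obj, "n_components") != len(output_cols):
        return "ARRAY"
    # Propagate the input type only when it is uniform and the widths line up.
    if len(set(input_sf_types)) == 1 and len(input_sf_types) == len(output_cols):
        return input_sf_types[0]
    return None


class FakePCA:
    n_components = 3


# A decomposition with 3 components written to a single output column -> "ARRAY"
print(infer_expected_dtype(FakePCA(), ["OUT"], ["FLOAT"]))
```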
--- a/snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py
+++ b/snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.kernel_approximation".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class AdditiveChi2Sampler(BaseTransformer):
     r"""Approximate feature map for additive chi2 kernel
     For more details on this class, see [sklearn.kernel_approximation.AdditiveChi2Sampler]
@@ -130,7 +142,9 @@ class AdditiveChi2Sampler(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -207,11 +221,6 @@ class AdditiveChi2Sampler(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -239,7 +248,9 @@ class AdditiveChi2Sampler(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -510,6 +521,22 @@ class AdditiveChi2Sampler(BaseTransformer):
         # each row containing a list of values.
         expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -525,8 +552,8 @@ class AdditiveChi2Sampler(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -539,13 +566,21 @@ class AdditiveChi2Sampler(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
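For wrapper classes where the generator emits an enabled check (clustering estimators, for instance), the new `fit_predict` body fits and then returns the underlying estimator's `labels_`, while `fit_transform` returns `embedding_` where that attribute exists (manifold learners under `snowflake/ml/modeling/manifold/`). A hedged usage sketch against the cluster wrappers from the file list (`feature_cols` and `train_df` are assumed to exist; whether a given class actually enables the method depends on its generated check):

```python
# Hedged sketch: the wrappers mirror the sklearn constructor and add column
# bindings; feature_cols (List[str]) and train_df (a Snowpark DataFrame) are
# assumptions for this example.
from snowflake.ml.modeling.cluster import KMeans

km = KMeans(n_clusters=4, input_cols=feature_cols, output_cols=["CLUSTER"])
labels = km.fit_predict(train_df)  # per the new body: self.fit(), then labels_
```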
--- a/snowflake/ml/modeling/kernel_approximation/nystroem.py
+++ b/snowflake/ml/modeling/kernel_approximation/nystroem.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.kernel_approximation".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class Nystroem(BaseTransformer):
     r"""Approximate a kernel map using a subset of the training data
     For more details on this class, see [sklearn.kernel_approximation.Nystroem]
@@ -172,7 +184,9 @@ class Nystroem(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -255,11 +269,6 @@ class Nystroem(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -287,7 +296,9 @@ class Nystroem(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -558,6 +569,22 @@ class Nystroem(BaseTransformer):
         # each row containing a list of values.
         expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -573,8 +600,8 @@ class Nystroem(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -587,13 +614,21 @@ class Nystroem(BaseTransformer):
        Returns:
            Predicted dataset.
        """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
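Finally, every wrapper now initializes `_use_external_memory_version` / `_batch_size` and forwards them to the trainer builder, which is how the new `snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py` (+444 lines) gets wired in; the XGBoost wrappers (`+69 -16` in the file list) are the ones that expose the feature. That trainer's code is not part of this diff; the sketch below shows the general XGBoost external-memory pattern it plausibly builds on, streaming batches through a `DataIter` instead of materializing one matrix (all names are local to the example):

```python
# General XGBoost external-memory pattern (not the package's internal code):
# feed training data batch by batch so the full dataset never sits in memory.
from typing import List, Tuple

import numpy as np
import xgboost


class BatchIter(xgboost.DataIter):
    def __init__(self, batches: List[Tuple[np.ndarray, np.ndarray]]):
        self._batches = batches
        self._i = 0
        super().__init__(cache_prefix="./xgb_cache")  # on-disk cache for data pages

    def next(self, input_data) -> bool:
        if self._i == len(self._batches):
            return False                 # end of one pass over the data
        X, y = self._batches[self._i]
        input_data(data=X, label=y)      # hand exactly one batch to XGBoost
        self._i += 1
        return True

    def reset(self) -> None:
        self._i = 0                      # rewind for the next pass


rng = np.random.default_rng(0)
batches = [(rng.normal(size=(100, 5)), rng.integers(0, 2, 100)) for _ in range(3)]
dmat = xgboost.DMatrix(BatchIter(batches))  # built incrementally from the iterator
booster = xgboost.train({"tree_method": "hist", "objective": "binary:logistic"}, dmat)
```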