snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
- snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
- snowflake/ml/_internal/env_utils.py +31 -52
- snowflake/ml/_internal/file_utils.py +17 -0
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/_internal/utils/query_result_checker.py +8 -5
- snowflake/ml/_internal/utils/snowflake_env.py +95 -0
- snowflake/ml/fileset/parquet_parser.py +31 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/model_impl.py +172 -13
- snowflake/ml/model/_client/model/model_version_impl.py +96 -52
- snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
- snowflake/ml/model/_client/ops/model_ops.py +155 -9
- snowflake/ml/model/_client/sql/model.py +55 -10
- snowflake/ml/model/_client/sql/model_version.py +72 -61
- snowflake/ml/model/_client/sql/stage.py +10 -4
- snowflake/ml/model/_client/sql/tag.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/_signatures/core.py +20 -17
- snowflake/ml/model/custom_model.py +30 -27
- snowflake/ml/model/model_signature.py +16 -17
- snowflake/ml/model/type_hints.py +3 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
- snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
- snowflake/ml/modeling/_internal/model_specifications.py +3 -10
- snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
- snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
- snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
- snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
- snowflake/ml/modeling/cluster/birch.py +51 -16
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
- snowflake/ml/modeling/cluster/dbscan.py +51 -16
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
- snowflake/ml/modeling/cluster/k_means.py +51 -16
- snowflake/ml/modeling/cluster/mean_shift.py +51 -16
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
- snowflake/ml/modeling/cluster/optics.py +51 -16
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
- snowflake/ml/modeling/compose/column_transformer.py +51 -16
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
- snowflake/ml/modeling/covariance/oas.py +51 -16
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/pca.py +51 -16
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
- snowflake/ml/modeling/impute/knn_imputer.py +51 -16
- snowflake/ml/modeling/impute/missing_indicator.py +51 -16
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/lars.py +51 -16
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/perceptron.py +51 -16
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ridge.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
- snowflake/ml/modeling/manifold/isomap.py +51 -16
- snowflake/ml/modeling/manifold/mds.py +51 -16
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
- snowflake/ml/modeling/manifold/tsne.py +51 -16
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
- snowflake/ml/modeling/svm/linear_svc.py +51 -16
- snowflake/ml/modeling/svm/linear_svr.py +51 -16
- snowflake/ml/modeling/svm/nu_svc.py +51 -16
- snowflake/ml/modeling/svm/nu_svr.py +51 -16
- snowflake/ml/modeling/svm/svc.py +51 -16
- snowflake/ml/modeling/svm/svr.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
- snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
- snowflake/ml/registry/__init__.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +163 -0
- snowflake/ml/registry/model_registry.py +12 -0
- snowflake/ml/registry/registry.py +100 -90
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
- snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
- snowflake/ml/model/_client/model/model_method_info.py +0 -19
- snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
- /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
- /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
|
|
54
54
|
_SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
|
55
55
|
|
56
56
|
|
57
|
+
def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
|
58
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
59
|
+
return False and callable(getattr(self._sklearn_object, "fit_predict", None))
|
60
|
+
return check
|
61
|
+
|
62
|
+
|
63
|
+
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
+
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
+
return check
|
67
|
+
|
68
|
+
|
57
69
|
class PoissonRegressor(BaseTransformer):
|
58
70
|
r"""Generalized Linear Model with a Poisson distribution
|
59
71
|
For more details on this class, see [sklearn.linear_model.PoissonRegressor]
|
@@ -175,7 +187,9 @@ class PoissonRegressor(BaseTransformer):
|
|
175
187
|
self.set_label_cols(label_cols)
|
176
188
|
self.set_passthrough_cols(passthrough_cols)
|
177
189
|
self.set_drop_input_cols(drop_input_cols)
|
178
|
-
self.set_sample_weight_col(sample_weight_col)
|
190
|
+
self.set_sample_weight_col(sample_weight_col)
|
191
|
+
self._use_external_memory_version = False
|
192
|
+
self._batch_size = -1
|
179
193
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
180
194
|
|
181
195
|
self._deps = list(deps)
|
@@ -257,11 +271,6 @@ class PoissonRegressor(BaseTransformer):
|
|
257
271
|
if isinstance(dataset, DataFrame):
|
258
272
|
session = dataset._session
|
259
273
|
assert session is not None # keep mypy happy
|
260
|
-
# Validate that key package version in user workspace are supported in snowflake conda channel
|
261
|
-
# If customer doesn't have package in conda channel, replace the ones have the closest versions
|
262
|
-
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
263
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
264
|
-
|
265
274
|
# Specify input columns so column pruning will be enforced
|
266
275
|
selected_cols = self._get_active_columns()
|
267
276
|
if len(selected_cols) > 0:
|
@@ -289,7 +298,9 @@ class PoissonRegressor(BaseTransformer):
|
|
289
298
|
label_cols=self.label_cols,
|
290
299
|
sample_weight_col=self.sample_weight_col,
|
291
300
|
autogenerated=self._autogenerated,
|
292
|
-
subproject=_SUBPROJECT
|
301
|
+
subproject=_SUBPROJECT,
|
302
|
+
use_external_memory_version=self._use_external_memory_version,
|
303
|
+
batch_size=self._batch_size,
|
293
304
|
)
|
294
305
|
self._sklearn_object = model_trainer.train()
|
295
306
|
self._is_fitted = True
|
@@ -560,6 +571,22 @@ class PoissonRegressor(BaseTransformer):
|
|
560
571
|
# each row containing a list of values.
|
561
572
|
expected_dtype = "ARRAY"
|
562
573
|
|
574
|
+
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
575
|
+
if expected_dtype == "":
|
576
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
|
577
|
+
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
578
|
+
expected_dtype = "ARRAY"
|
579
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
|
580
|
+
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
581
|
+
expected_dtype = "ARRAY"
|
582
|
+
else:
|
583
|
+
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
584
|
+
# We can only infer the output types from the input types if the following two statements are true:
|
585
|
+
# 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
|
586
|
+
# 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
|
587
|
+
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
588
|
+
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
589
|
+
|
563
590
|
output_df = self._batch_inference(
|
564
591
|
dataset=dataset,
|
565
592
|
inference_method="transform",
|
@@ -575,8 +602,8 @@ class PoissonRegressor(BaseTransformer):
|
|
575
602
|
|
576
603
|
return output_df
|
577
604
|
|
578
|
-
@available_if(
|
579
|
-
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
|
605
|
+
@available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
|
606
|
+
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
580
607
|
""" Method not supported for this class.
|
581
608
|
|
582
609
|
|
@@ -589,13 +616,21 @@ class PoissonRegressor(BaseTransformer):
|
|
589
616
|
Returns:
|
590
617
|
Predicted dataset.
|
591
618
|
"""
|
592
|
-
|
593
|
-
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
598
|
-
|
619
|
+
self.fit(dataset)
|
620
|
+
assert self._sklearn_object is not None
|
621
|
+
return self._sklearn_object.labels_
|
622
|
+
|
623
|
+
|
624
|
+
@available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
|
625
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
626
|
+
"""
|
627
|
+
Returns:
|
628
|
+
Transformed dataset.
|
629
|
+
"""
|
630
|
+
self.fit(dataset)
|
631
|
+
assert self._sklearn_object is not None
|
632
|
+
return self._sklearn_object.embedding_
|
633
|
+
|
599
634
|
|
600
635
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
601
636
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
|
|
54
54
|
_SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
|
55
55
|
|
56
56
|
|
57
|
+
def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
|
58
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
59
|
+
return False and callable(getattr(self._sklearn_object, "fit_predict", None))
|
60
|
+
return check
|
61
|
+
|
62
|
+
|
63
|
+
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
+
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
+
return check
|
67
|
+
|
68
|
+
|
57
69
|
class RANSACRegressor(BaseTransformer):
|
58
70
|
r"""RANSAC (RANdom SAmple Consensus) algorithm
|
59
71
|
For more details on this class, see [sklearn.linear_model.RANSACRegressor]
|
@@ -226,7 +238,9 @@ class RANSACRegressor(BaseTransformer):
|
|
226
238
|
self.set_label_cols(label_cols)
|
227
239
|
self.set_passthrough_cols(passthrough_cols)
|
228
240
|
self.set_drop_input_cols(drop_input_cols)
|
229
|
-
self.set_sample_weight_col(sample_weight_col)
|
241
|
+
self.set_sample_weight_col(sample_weight_col)
|
242
|
+
self._use_external_memory_version = False
|
243
|
+
self._batch_size = -1
|
230
244
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
231
245
|
deps = deps | gather_dependencies(estimator)
|
232
246
|
self._deps = list(deps)
|
@@ -313,11 +327,6 @@ class RANSACRegressor(BaseTransformer):
|
|
313
327
|
if isinstance(dataset, DataFrame):
|
314
328
|
session = dataset._session
|
315
329
|
assert session is not None # keep mypy happy
|
316
|
-
# Validate that key package version in user workspace are supported in snowflake conda channel
|
317
|
-
# If customer doesn't have package in conda channel, replace the ones have the closest versions
|
318
|
-
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
319
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
320
|
-
|
321
330
|
# Specify input columns so column pruning will be enforced
|
322
331
|
selected_cols = self._get_active_columns()
|
323
332
|
if len(selected_cols) > 0:
|
@@ -345,7 +354,9 @@ class RANSACRegressor(BaseTransformer):
|
|
345
354
|
label_cols=self.label_cols,
|
346
355
|
sample_weight_col=self.sample_weight_col,
|
347
356
|
autogenerated=self._autogenerated,
|
348
|
-
subproject=_SUBPROJECT
|
357
|
+
subproject=_SUBPROJECT,
|
358
|
+
use_external_memory_version=self._use_external_memory_version,
|
359
|
+
batch_size=self._batch_size,
|
349
360
|
)
|
350
361
|
self._sklearn_object = model_trainer.train()
|
351
362
|
self._is_fitted = True
|
@@ -616,6 +627,22 @@ class RANSACRegressor(BaseTransformer):
|
|
616
627
|
# each row containing a list of values.
|
617
628
|
expected_dtype = "ARRAY"
|
618
629
|
|
630
|
+
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
631
|
+
if expected_dtype == "":
|
632
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
|
633
|
+
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
634
|
+
expected_dtype = "ARRAY"
|
635
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
|
636
|
+
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
637
|
+
expected_dtype = "ARRAY"
|
638
|
+
else:
|
639
|
+
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
640
|
+
# We can only infer the output types from the input types if the following two statements are true:
|
641
|
+
# 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
|
642
|
+
# 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
|
643
|
+
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
644
|
+
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
645
|
+
|
619
646
|
output_df = self._batch_inference(
|
620
647
|
dataset=dataset,
|
621
648
|
inference_method="transform",
|
@@ -631,8 +658,8 @@ class RANSACRegressor(BaseTransformer):
|
|
631
658
|
|
632
659
|
return output_df
|
633
660
|
|
634
|
-
@available_if(
|
635
|
-
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
|
661
|
+
@available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
|
662
|
+
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
636
663
|
""" Method not supported for this class.
|
637
664
|
|
638
665
|
|
@@ -645,13 +672,21 @@ class RANSACRegressor(BaseTransformer):
|
|
645
672
|
Returns:
|
646
673
|
Predicted dataset.
|
647
674
|
"""
|
648
|
-
|
649
|
-
|
650
|
-
|
651
|
-
|
652
|
-
|
653
|
-
|
654
|
-
|
675
|
+
self.fit(dataset)
|
676
|
+
assert self._sklearn_object is not None
|
677
|
+
return self._sklearn_object.labels_
|
678
|
+
|
679
|
+
|
680
|
+
@available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
|
681
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
682
|
+
"""
|
683
|
+
Returns:
|
684
|
+
Transformed dataset.
|
685
|
+
"""
|
686
|
+
self.fit(dataset)
|
687
|
+
assert self._sklearn_object is not None
|
688
|
+
return self._sklearn_object.embedding_
|
689
|
+
|
655
690
|
|
656
691
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
657
692
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
|
|
54
54
|
_SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
|
55
55
|
|
56
56
|
|
57
|
+
def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
|
58
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
59
|
+
return False and callable(getattr(self._sklearn_object, "fit_predict", None))
|
60
|
+
return check
|
61
|
+
|
62
|
+
|
63
|
+
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
+
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
+
return check
|
67
|
+
|
68
|
+
|
57
69
|
class Ridge(BaseTransformer):
|
58
70
|
r"""Linear least squares with l2 regularization
|
59
71
|
For more details on this class, see [sklearn.linear_model.Ridge]
|
@@ -222,7 +234,9 @@ class Ridge(BaseTransformer):
|
|
222
234
|
self.set_label_cols(label_cols)
|
223
235
|
self.set_passthrough_cols(passthrough_cols)
|
224
236
|
self.set_drop_input_cols(drop_input_cols)
|
225
|
-
self.set_sample_weight_col(sample_weight_col)
|
237
|
+
self.set_sample_weight_col(sample_weight_col)
|
238
|
+
self._use_external_memory_version = False
|
239
|
+
self._batch_size = -1
|
226
240
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
227
241
|
|
228
242
|
self._deps = list(deps)
|
@@ -305,11 +319,6 @@ class Ridge(BaseTransformer):
|
|
305
319
|
if isinstance(dataset, DataFrame):
|
306
320
|
session = dataset._session
|
307
321
|
assert session is not None # keep mypy happy
|
308
|
-
# Validate that key package version in user workspace are supported in snowflake conda channel
|
309
|
-
# If customer doesn't have package in conda channel, replace the ones have the closest versions
|
310
|
-
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
311
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
312
|
-
|
313
322
|
# Specify input columns so column pruning will be enforced
|
314
323
|
selected_cols = self._get_active_columns()
|
315
324
|
if len(selected_cols) > 0:
|
@@ -337,7 +346,9 @@ class Ridge(BaseTransformer):
|
|
337
346
|
label_cols=self.label_cols,
|
338
347
|
sample_weight_col=self.sample_weight_col,
|
339
348
|
autogenerated=self._autogenerated,
|
340
|
-
subproject=_SUBPROJECT
|
349
|
+
subproject=_SUBPROJECT,
|
350
|
+
use_external_memory_version=self._use_external_memory_version,
|
351
|
+
batch_size=self._batch_size,
|
341
352
|
)
|
342
353
|
self._sklearn_object = model_trainer.train()
|
343
354
|
self._is_fitted = True
|
@@ -608,6 +619,22 @@ class Ridge(BaseTransformer):
|
|
608
619
|
# each row containing a list of values.
|
609
620
|
expected_dtype = "ARRAY"
|
610
621
|
|
622
|
+
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
623
|
+
if expected_dtype == "":
|
624
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
|
625
|
+
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
626
|
+
expected_dtype = "ARRAY"
|
627
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
|
628
|
+
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
629
|
+
expected_dtype = "ARRAY"
|
630
|
+
else:
|
631
|
+
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
632
|
+
# We can only infer the output types from the input types if the following two statements are true:
|
633
|
+
# 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
|
634
|
+
# 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
|
635
|
+
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
636
|
+
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
637
|
+
|
611
638
|
output_df = self._batch_inference(
|
612
639
|
dataset=dataset,
|
613
640
|
inference_method="transform",
|
@@ -623,8 +650,8 @@ class Ridge(BaseTransformer):
|
|
623
650
|
|
624
651
|
return output_df
|
625
652
|
|
626
|
-
@available_if(
|
627
|
-
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
|
653
|
+
@available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
|
654
|
+
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
628
655
|
""" Method not supported for this class.
|
629
656
|
|
630
657
|
|
@@ -637,13 +664,21 @@ class Ridge(BaseTransformer):
|
|
637
664
|
Returns:
|
638
665
|
Predicted dataset.
|
639
666
|
"""
|
640
|
-
|
641
|
-
|
642
|
-
|
643
|
-
|
644
|
-
|
645
|
-
|
646
|
-
|
667
|
+
self.fit(dataset)
|
668
|
+
assert self._sklearn_object is not None
|
669
|
+
return self._sklearn_object.labels_
|
670
|
+
|
671
|
+
|
672
|
+
@available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
|
673
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
674
|
+
"""
|
675
|
+
Returns:
|
676
|
+
Transformed dataset.
|
677
|
+
"""
|
678
|
+
self.fit(dataset)
|
679
|
+
assert self._sklearn_object is not None
|
680
|
+
return self._sklearn_object.embedding_
|
681
|
+
|
647
682
|
|
648
683
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
649
684
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
|
|
54
54
|
_SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
|
55
55
|
|
56
56
|
|
57
|
+
def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
|
58
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
59
|
+
return False and callable(getattr(self._sklearn_object, "fit_predict", None))
|
60
|
+
return check
|
61
|
+
|
62
|
+
|
63
|
+
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
+
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
+
return check
|
67
|
+
|
68
|
+
|
57
69
|
class RidgeClassifier(BaseTransformer):
|
58
70
|
r"""Classifier using Ridge regression
|
59
71
|
For more details on this class, see [sklearn.linear_model.RidgeClassifier]
|
@@ -221,7 +233,9 @@ class RidgeClassifier(BaseTransformer):
|
|
221
233
|
self.set_label_cols(label_cols)
|
222
234
|
self.set_passthrough_cols(passthrough_cols)
|
223
235
|
self.set_drop_input_cols(drop_input_cols)
|
224
|
-
self.set_sample_weight_col(sample_weight_col)
|
236
|
+
self.set_sample_weight_col(sample_weight_col)
|
237
|
+
self._use_external_memory_version = False
|
238
|
+
self._batch_size = -1
|
225
239
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
226
240
|
|
227
241
|
self._deps = list(deps)
|
@@ -305,11 +319,6 @@ class RidgeClassifier(BaseTransformer):
|
|
305
319
|
if isinstance(dataset, DataFrame):
|
306
320
|
session = dataset._session
|
307
321
|
assert session is not None # keep mypy happy
|
308
|
-
# Validate that key package version in user workspace are supported in snowflake conda channel
|
309
|
-
# If customer doesn't have package in conda channel, replace the ones have the closest versions
|
310
|
-
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
311
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
312
|
-
|
313
322
|
# Specify input columns so column pruning will be enforced
|
314
323
|
selected_cols = self._get_active_columns()
|
315
324
|
if len(selected_cols) > 0:
|
@@ -337,7 +346,9 @@ class RidgeClassifier(BaseTransformer):
|
|
337
346
|
label_cols=self.label_cols,
|
338
347
|
sample_weight_col=self.sample_weight_col,
|
339
348
|
autogenerated=self._autogenerated,
|
340
|
-
subproject=_SUBPROJECT
|
349
|
+
subproject=_SUBPROJECT,
|
350
|
+
use_external_memory_version=self._use_external_memory_version,
|
351
|
+
batch_size=self._batch_size,
|
341
352
|
)
|
342
353
|
self._sklearn_object = model_trainer.train()
|
343
354
|
self._is_fitted = True
|
@@ -608,6 +619,22 @@ class RidgeClassifier(BaseTransformer):
|
|
608
619
|
# each row containing a list of values.
|
609
620
|
expected_dtype = "ARRAY"
|
610
621
|
|
622
|
+
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
623
|
+
if expected_dtype == "":
|
624
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
|
625
|
+
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
626
|
+
expected_dtype = "ARRAY"
|
627
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
|
628
|
+
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
629
|
+
expected_dtype = "ARRAY"
|
630
|
+
else:
|
631
|
+
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
632
|
+
# We can only infer the output types from the input types if the following two statements are true:
|
633
|
+
# 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
|
634
|
+
# 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
|
635
|
+
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
636
|
+
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
637
|
+
|
611
638
|
output_df = self._batch_inference(
|
612
639
|
dataset=dataset,
|
613
640
|
inference_method="transform",
|
@@ -623,8 +650,8 @@ class RidgeClassifier(BaseTransformer):
|
|
623
650
|
|
624
651
|
return output_df
|
625
652
|
|
626
|
-
@available_if(
|
627
|
-
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
|
653
|
+
@available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
|
654
|
+
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
628
655
|
""" Method not supported for this class.
|
629
656
|
|
630
657
|
|
@@ -637,13 +664,21 @@ class RidgeClassifier(BaseTransformer):
|
|
637
664
|
Returns:
|
638
665
|
Predicted dataset.
|
639
666
|
"""
|
640
|
-
|
641
|
-
|
642
|
-
|
643
|
-
|
644
|
-
|
645
|
-
|
646
|
-
|
667
|
+
self.fit(dataset)
|
668
|
+
assert self._sklearn_object is not None
|
669
|
+
return self._sklearn_object.labels_
|
670
|
+
|
671
|
+
|
672
|
+
@available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
|
673
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
674
|
+
"""
|
675
|
+
Returns:
|
676
|
+
Transformed dataset.
|
677
|
+
"""
|
678
|
+
self.fit(dataset)
|
679
|
+
assert self._sklearn_object is not None
|
680
|
+
return self._sklearn_object.embedding_
|
681
|
+
|
647
682
|
|
648
683
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
649
684
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
|
|
54
54
|
_SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
|
55
55
|
|
56
56
|
|
57
|
+
def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
|
58
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
59
|
+
return False and callable(getattr(self._sklearn_object, "fit_predict", None))
|
60
|
+
return check
|
61
|
+
|
62
|
+
|
63
|
+
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
+
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
+
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
+
return check
|
67
|
+
|
68
|
+
|
57
69
|
class RidgeClassifierCV(BaseTransformer):
|
58
70
|
r"""Ridge classifier with built-in cross-validation
|
59
71
|
For more details on this class, see [sklearn.linear_model.RidgeClassifierCV]
|
@@ -175,7 +187,9 @@ class RidgeClassifierCV(BaseTransformer):
|
|
175
187
|
self.set_label_cols(label_cols)
|
176
188
|
self.set_passthrough_cols(passthrough_cols)
|
177
189
|
self.set_drop_input_cols(drop_input_cols)
|
178
|
-
self.set_sample_weight_col(sample_weight_col)
|
190
|
+
self.set_sample_weight_col(sample_weight_col)
|
191
|
+
self._use_external_memory_version = False
|
192
|
+
self._batch_size = -1
|
179
193
|
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
|
180
194
|
|
181
195
|
self._deps = list(deps)
|
@@ -256,11 +270,6 @@ class RidgeClassifierCV(BaseTransformer):
|
|
256
270
|
if isinstance(dataset, DataFrame):
|
257
271
|
session = dataset._session
|
258
272
|
assert session is not None # keep mypy happy
|
259
|
-
# Validate that key package version in user workspace are supported in snowflake conda channel
|
260
|
-
# If customer doesn't have package in conda channel, replace the ones have the closest versions
|
261
|
-
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
262
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
263
|
-
|
264
273
|
# Specify input columns so column pruning will be enforced
|
265
274
|
selected_cols = self._get_active_columns()
|
266
275
|
if len(selected_cols) > 0:
|
@@ -288,7 +297,9 @@ class RidgeClassifierCV(BaseTransformer):
|
|
288
297
|
label_cols=self.label_cols,
|
289
298
|
sample_weight_col=self.sample_weight_col,
|
290
299
|
autogenerated=self._autogenerated,
|
291
|
-
subproject=_SUBPROJECT
|
300
|
+
subproject=_SUBPROJECT,
|
301
|
+
use_external_memory_version=self._use_external_memory_version,
|
302
|
+
batch_size=self._batch_size,
|
292
303
|
)
|
293
304
|
self._sklearn_object = model_trainer.train()
|
294
305
|
self._is_fitted = True
|
@@ -559,6 +570,22 @@ class RidgeClassifierCV(BaseTransformer):
|
|
559
570
|
# each row containing a list of values.
|
560
571
|
expected_dtype = "ARRAY"
|
561
572
|
|
573
|
+
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
574
|
+
if expected_dtype == "":
|
575
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
|
576
|
+
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
577
|
+
expected_dtype = "ARRAY"
|
578
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
|
579
|
+
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
580
|
+
expected_dtype = "ARRAY"
|
581
|
+
else:
|
582
|
+
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
583
|
+
# We can only infer the output types from the input types if the following two statements are true:
|
584
|
+
# 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
|
585
|
+
# 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
|
586
|
+
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
587
|
+
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
588
|
+
|
562
589
|
output_df = self._batch_inference(
|
563
590
|
dataset=dataset,
|
564
591
|
inference_method="transform",
|
@@ -574,8 +601,8 @@ class RidgeClassifierCV(BaseTransformer):
|
|
574
601
|
|
575
602
|
return output_df
|
576
603
|
|
577
|
-
@available_if(
|
578
|
-
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
|
604
|
+
@available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
|
605
|
+
def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
579
606
|
""" Method not supported for this class.
|
580
607
|
|
581
608
|
|
@@ -588,13 +615,21 @@ class RidgeClassifierCV(BaseTransformer):
|
|
588
615
|
Returns:
|
589
616
|
Predicted dataset.
|
590
617
|
"""
|
591
|
-
|
592
|
-
|
593
|
-
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
618
|
+
self.fit(dataset)
|
619
|
+
assert self._sklearn_object is not None
|
620
|
+
return self._sklearn_object.labels_
|
621
|
+
|
622
|
+
|
623
|
+
@available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
|
624
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
|
625
|
+
"""
|
626
|
+
Returns:
|
627
|
+
Transformed dataset.
|
628
|
+
"""
|
629
|
+
self.fit(dataset)
|
630
|
+
assert self._sklearn_object is not None
|
631
|
+
return self._sklearn_object.embedding_
|
632
|
+
|
598
633
|
|
599
634
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
600
635
|
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|