snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
- snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
- snowflake/ml/_internal/env_utils.py +31 -52
- snowflake/ml/_internal/file_utils.py +17 -0
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/_internal/utils/query_result_checker.py +8 -5
- snowflake/ml/_internal/utils/snowflake_env.py +95 -0
- snowflake/ml/fileset/parquet_parser.py +31 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/model_impl.py +172 -13
- snowflake/ml/model/_client/model/model_version_impl.py +96 -52
- snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
- snowflake/ml/model/_client/ops/model_ops.py +155 -9
- snowflake/ml/model/_client/sql/model.py +55 -10
- snowflake/ml/model/_client/sql/model_version.py +72 -61
- snowflake/ml/model/_client/sql/stage.py +10 -4
- snowflake/ml/model/_client/sql/tag.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/_signatures/core.py +20 -17
- snowflake/ml/model/custom_model.py +30 -27
- snowflake/ml/model/model_signature.py +16 -17
- snowflake/ml/model/type_hints.py +3 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
- snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
- snowflake/ml/modeling/_internal/model_specifications.py +3 -10
- snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
- snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
- snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
- snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
- snowflake/ml/modeling/cluster/birch.py +51 -16
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
- snowflake/ml/modeling/cluster/dbscan.py +51 -16
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
- snowflake/ml/modeling/cluster/k_means.py +51 -16
- snowflake/ml/modeling/cluster/mean_shift.py +51 -16
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
- snowflake/ml/modeling/cluster/optics.py +51 -16
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
- snowflake/ml/modeling/compose/column_transformer.py +51 -16
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
- snowflake/ml/modeling/covariance/oas.py +51 -16
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/pca.py +51 -16
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
- snowflake/ml/modeling/impute/knn_imputer.py +51 -16
- snowflake/ml/modeling/impute/missing_indicator.py +51 -16
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/lars.py +51 -16
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/perceptron.py +51 -16
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ridge.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
- snowflake/ml/modeling/manifold/isomap.py +51 -16
- snowflake/ml/modeling/manifold/mds.py +51 -16
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
- snowflake/ml/modeling/manifold/tsne.py +51 -16
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
- snowflake/ml/modeling/svm/linear_svc.py +51 -16
- snowflake/ml/modeling/svm/linear_svr.py +51 -16
- snowflake/ml/modeling/svm/nu_svc.py +51 -16
- snowflake/ml/modeling/svm/nu_svr.py +51 -16
- snowflake/ml/modeling/svm/svc.py +51 -16
- snowflake/ml/modeling/svm/svr.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
- snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
- snowflake/ml/registry/__init__.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +163 -0
- snowflake/ml/registry/model_registry.py +12 -0
- snowflake/ml/registry/registry.py +100 -90
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
- snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
- snowflake/ml/model/_client/model/model_method_info.py +0 -19
- snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
- /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
- /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
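
The recurring +51 −16 change to the autogenerated `snowflake/ml/modeling/*` estimator classes repeats across files; the hunks below show it for `LGBMRegressor`, `ARDRegression`, `BayesianRidge`, `ElasticNet`, and `ElasticNetCV`. Per class, 1.2.1 adds module-level `_is_fit_predict_method_enabled` / `_is_fit_transform_method_enabled` gates, initializes `_use_external_memory_version` and `_batch_size` and threads them into the model-trainer call, drops the eager conda-channel package-version validation from the fit path, adds a fallback output-type inference to `transform`, and fills in `fit_predict` / `fit_transform` bodies behind `available_if`.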
snowflake/ml/modeling/lightgbm/lgbm_regressor.py:

```diff
@@ -53,6 +53,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "lightgbm".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class LGBMRegressor(BaseTransformer):
     r"""LightGBM regressor
     For more details on this class, see [lightgbm.LGBMRegressor]
@@ -144,7 +156,9 @@ class LGBMRegressor(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'lightgbm=={lightgbm.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -240,11 +254,6 @@ class LGBMRegressor(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -272,7 +281,9 @@ class LGBMRegressor(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -543,6 +554,22 @@ class LGBMRegressor(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -558,8 +585,8 @@ class LGBMRegressor(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -572,13 +599,21 @@ class LGBMRegressor(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
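
The gating above relies on scikit-learn's `available_if` (from `sklearn.utils.metaestimators`, public since scikit-learn 1.0), which exposes a method only when a per-instance check passes. A minimal, self-contained sketch of that mechanism follows; the `Wrapped` / `_impl` names are illustrative, not part of the snowflake-ml API:

```python
# Sketch only: demonstrates the available_if gating used in the diff above,
# assuming scikit-learn >= 1.0. Wrapped/_impl are hypothetical names.
import numpy as np
from sklearn.cluster import KMeans
from sklearn.utils.metaestimators import available_if


def _fit_predict_check(self):
    # available_if calls this with the instance on attribute lookup; a falsy
    # result makes the decorated method raise AttributeError, so hasattr()
    # reports False and the method is effectively absent.
    return callable(getattr(self._impl, "fit_predict", None))


class Wrapped:
    def __init__(self, impl):
        self._impl = impl

    @available_if(_fit_predict_check)
    def fit_predict(self, X):
        return self._impl.fit_predict(X)


X = np.array([[0.0], [0.1], [5.0], [5.1]])
print(hasattr(Wrapped(KMeans(n_clusters=2, n_init=10)), "fit_predict"))  # True
print(Wrapped(KMeans(n_clusters=2, n_init=10)).fit_predict(X))           # cluster labels
print(hasattr(Wrapped(object()), "fit_predict"))                         # False
```

Note that the generated checks in the diff are hardcoded to `return False and ...`, so `fit_predict` / `fit_transform` remain unavailable on these regressor classes; the decorated bodies (which return `labels_` / `embedding_`, attributes regressors lack) are reachable only for estimator types whose generated check passes.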
snowflake/ml/modeling/linear_model/ard_regression.py:

```diff
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class ARDRegression(BaseTransformer):
     r"""Bayesian ARD regression
     For more details on this class, see [sklearn.linear_model.ARDRegression]
@@ -179,7 +191,9 @@ class ARDRegression(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -266,11 +280,6 @@ class ARDRegression(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -298,7 +307,9 @@ class ARDRegression(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -569,6 +580,22 @@ class ARDRegression(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -584,8 +611,8 @@ class ARDRegression(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -598,13 +625,21 @@ class ARDRegression(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
snowflake/ml/modeling/linear_model/bayesian_ridge.py:

```diff
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class BayesianRidge(BaseTransformer):
     r"""Bayesian ridge regression
     For more details on this class, see [sklearn.linear_model.BayesianRidge]
@@ -189,7 +201,9 @@ class BayesianRidge(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -277,11 +291,6 @@ class BayesianRidge(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -309,7 +318,9 @@ class BayesianRidge(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -580,6 +591,22 @@ class BayesianRidge(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -595,8 +622,8 @@ class BayesianRidge(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -609,13 +636,21 @@ class BayesianRidge(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
snowflake/ml/modeling/linear_model/elastic_net.py:

```diff
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class ElasticNet(BaseTransformer):
     r"""Linear regression with combined L1 and L2 priors as regularizer
     For more details on this class, see [sklearn.linear_model.ElasticNet]
@@ -190,7 +202,9 @@ class ElasticNet(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -276,11 +290,6 @@ class ElasticNet(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -308,7 +317,9 @@ class ElasticNet(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -579,6 +590,22 @@ class ElasticNet(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -594,8 +621,8 @@ class ElasticNet(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -608,13 +635,21 @@ class ElasticNet(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
snowflake/ml/modeling/linear_model/elastic_net_cv.py:

```diff
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class ElasticNetCV(BaseTransformer):
     r"""Elastic Net model with iterative fitting along a regularization path
     For more details on this class, see [sklearn.linear_model.ElasticNetCV]
@@ -222,7 +234,9 @@ class ElasticNetCV(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
        self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -312,11 +326,6 @@ class ElasticNetCV(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -344,7 +353,9 @@ class ElasticNetCV(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -615,6 +626,22 @@ class ElasticNetCV(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -630,8 +657,8 @@ class ElasticNetCV(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -644,13 +671,21 @@ class ElasticNetCV(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
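
For reference, the fallback output-type inference added to each `transform` wrapper above reduces to the following standalone logic. This is a sketch only: plain strings stand in for the Snowpark type objects and for `convert_sp_to_sf_type`, and `infer_expected_dtype` is a hypothetical name, not a function in the package.

```python
from typing import List, Sequence


def infer_expected_dtype(estimator: object, output_cols: Sequence[str], output_types: List[str]) -> str:
    """Sketch of the 1.2.1 fallback: pick one concrete type, else "" (variant/ARRAY)."""
    # Clustering transformers: if output columns don't match n_clusters,
    # each row holds a list of values, so the response is an ARRAY.
    if hasattr(estimator, "n_clusters") and getattr(estimator, "n_clusters") != len(output_cols):
        return "ARRAY"
    # Decomposition transformers: same rule against n_components.
    if hasattr(estimator, "n_components") and getattr(estimator, "n_components") != len(output_cols):
        return "ARRAY"
    # A single concrete type is only safe when every inferred input type agrees
    # and the input/output column counts line up; otherwise stay with variant.
    if output_types and all(t == output_types[0] for t in output_types) and len(output_types) == len(output_cols):
        return output_types[0]
    return ""


# Homogeneous inputs with matching column counts keep a concrete type:
print(infer_expected_dtype(object(), ["OUT1", "OUT2"], ["DOUBLE", "DOUBLE"]))  # DOUBLE
print(infer_expected_dtype(object(), ["OUT1"], ["DOUBLE", "VARCHAR"]))         # ""
```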