snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
- snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
- snowflake/ml/_internal/env_utils.py +31 -52
- snowflake/ml/_internal/file_utils.py +17 -0
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/_internal/utils/query_result_checker.py +8 -5
- snowflake/ml/_internal/utils/snowflake_env.py +95 -0
- snowflake/ml/fileset/parquet_parser.py +31 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/model_impl.py +172 -13
- snowflake/ml/model/_client/model/model_version_impl.py +96 -52
- snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
- snowflake/ml/model/_client/ops/model_ops.py +155 -9
- snowflake/ml/model/_client/sql/model.py +55 -10
- snowflake/ml/model/_client/sql/model_version.py +72 -61
- snowflake/ml/model/_client/sql/stage.py +10 -4
- snowflake/ml/model/_client/sql/tag.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/_signatures/core.py +20 -17
- snowflake/ml/model/custom_model.py +30 -27
- snowflake/ml/model/model_signature.py +16 -17
- snowflake/ml/model/type_hints.py +3 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
- snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
- snowflake/ml/modeling/_internal/model_specifications.py +3 -10
- snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
- snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
- snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
- snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
- snowflake/ml/modeling/cluster/birch.py +51 -16
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
- snowflake/ml/modeling/cluster/dbscan.py +51 -16
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
- snowflake/ml/modeling/cluster/k_means.py +51 -16
- snowflake/ml/modeling/cluster/mean_shift.py +51 -16
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
- snowflake/ml/modeling/cluster/optics.py +51 -16
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
- snowflake/ml/modeling/compose/column_transformer.py +51 -16
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
- snowflake/ml/modeling/covariance/oas.py +51 -16
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/pca.py +51 -16
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
- snowflake/ml/modeling/impute/knn_imputer.py +51 -16
- snowflake/ml/modeling/impute/missing_indicator.py +51 -16
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/lars.py +51 -16
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/perceptron.py +51 -16
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ridge.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
- snowflake/ml/modeling/manifold/isomap.py +51 -16
- snowflake/ml/modeling/manifold/mds.py +51 -16
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
- snowflake/ml/modeling/manifold/tsne.py +51 -16
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
- snowflake/ml/modeling/svm/linear_svc.py +51 -16
- snowflake/ml/modeling/svm/linear_svr.py +51 -16
- snowflake/ml/modeling/svm/nu_svc.py +51 -16
- snowflake/ml/modeling/svm/nu_svr.py +51 -16
- snowflake/ml/modeling/svm/svc.py +51 -16
- snowflake/ml/modeling/svm/svr.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
- snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
- snowflake/ml/registry/__init__.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +163 -0
- snowflake/ml/registry/model_registry.py +12 -0
- snowflake/ml/registry/registry.py +100 -90
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
- snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
- snowflake/ml/model/_client/model/model_method_info.py +0 -19
- snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
- /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
- /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
snowflake/ml/modeling/linear_model/gamma_regressor.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class GammaRegressor(BaseTransformer):
     r"""Generalized Linear Model with a Gamma distribution
     For more details on this class, see [sklearn.linear_model.GammaRegressor]
@@ -175,7 +187,9 @@ class GammaRegressor(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -257,11 +271,6 @@ class GammaRegressor(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
         # Specify input columns so column pruning will be enforced
         selected_cols = self._get_active_columns()
         if len(selected_cols) > 0:
@@ -289,7 +298,9 @@ class GammaRegressor(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -560,6 +571,22 @@ class GammaRegressor(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -575,8 +602,8 @@ class GammaRegressor(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -589,13 +616,21 @@ class GammaRegressor(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
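The first hunk above is the pattern that repeats in every autogenerated estimator below: a module-level check factory is handed to scikit-learn's `available_if` descriptor, which exposes `fit_predict`/`fit_transform` only when the check passes. Because the generated check is `return False and callable(...)`, it short-circuits to `False`, so both methods stay hidden on these regressors (matching the "Method not supported for this class." docstring). Here is a minimal, self-contained sketch of the gating mechanism; `Wrapped` and `_has_fit_predict` are hypothetical stand-ins, not package code:

```python
# Sketch only: `Wrapped` and `_has_fit_predict` are illustrative names,
# not part of snowflake-ml-python.
from sklearn.utils.metaestimators import available_if


def _has_fit_predict():
    # available_if() expects a callable that receives the instance; a falsy
    # return makes attribute access raise AttributeError, so
    # hasattr(obj, "fit_predict") becomes False.
    def check(self):
        return callable(getattr(self._sklearn_object, "fit_predict", None))

    return check


class Wrapped:
    def __init__(self, sklearn_object):
        self._sklearn_object = sklearn_object

    @available_if(_has_fit_predict())
    def fit_predict(self, X):
        return self._sklearn_object.fit_predict(X)


from sklearn.cluster import KMeans
from sklearn.linear_model import GammaRegressor

print(hasattr(Wrapped(KMeans()), "fit_predict"))          # True: KMeans defines fit_predict
print(hasattr(Wrapped(GammaRegressor()), "fit_predict"))  # False: hidden by the check
```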
snowflake/ml/modeling/linear_model/huber_regressor.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class HuberRegressor(BaseTransformer):
     r"""L2-regularized linear regression model that is robust to outliers
     For more details on this class, see [sklearn.linear_model.HuberRegressor]
@@ -159,7 +171,9 @@ class HuberRegressor(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -240,11 +254,6 @@ class HuberRegressor(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
         # Specify input columns so column pruning will be enforced
         selected_cols = self._get_active_columns()
         if len(selected_cols) > 0:
@@ -272,7 +281,9 @@ class HuberRegressor(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -543,6 +554,22 @@ class HuberRegressor(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -558,8 +585,8 @@ class HuberRegressor(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -572,13 +599,21 @@ class HuberRegressor(BaseTransformer):
         Returns:
             Predicted dataset.
        """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
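Each constructor now pins `_use_external_memory_version = False` and `_batch_size = -1` and forwards both to the model trainer builder; per the file list above, the machinery behind these flags lands in `snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py` (note the XGBoost estimators' diffs are +69 rather than +51). For orientation only, below is the external-memory idiom that XGBoost itself documents for out-of-core training, sketched with an illustrative `BatchIter`; none of these names are snowflake-ml-python API:

```python
# Sketch of XGBoost's documented DataIter pattern; `BatchIter` and `batches`
# are illustrative, not part of snowflake-ml-python.
import numpy as np
import xgboost as xgb


class BatchIter(xgb.DataIter):
    """Feeds (X, y) batches to XGBoost so the full dataset never sits in RAM."""

    def __init__(self, batches):
        self._batches = batches
        self._i = 0
        super().__init__(cache_prefix="cache")  # prefix for the on-disk cache

    def next(self, input_data):
        if self._i == len(self._batches):
            return 0  # no more batches
        X, y = self._batches[self._i]
        input_data(data=X, label=y)  # hand the current batch to XGBoost
        self._i += 1
        return 1

    def reset(self):
        self._i = 0


rng = np.random.default_rng(0)
batches = [(rng.normal(size=(256, 4)), rng.normal(size=256)) for _ in range(4)]
dtrain = xgb.DMatrix(BatchIter(batches))
booster = xgb.train({"objective": "reg:squarederror"}, dtrain, num_boost_round=10)
```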
snowflake/ml/modeling/linear_model/lars.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class Lars(BaseTransformer):
     r"""Least Angle Regression model a
     For more details on this class, see [sklearn.linear_model.Lars]
@@ -184,7 +196,9 @@ class Lars(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -269,11 +283,6 @@ class Lars(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
         # Specify input columns so column pruning will be enforced
         selected_cols = self._get_active_columns()
         if len(selected_cols) > 0:
@@ -301,7 +310,9 @@ class Lars(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -572,6 +583,22 @@ class Lars(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -587,8 +614,8 @@ class Lars(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -601,13 +628,21 @@ class Lars(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
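The `transform` hunks add a fallback for the case where the factory left `expected_dtype` empty: a cluster or component count that disagrees with the number of output columns forces an `ARRAY` column, and otherwise a single concrete Snowflake type is propagated only when every inferred output type matches and the counts line up one-to-one. The same decision logic, restated as a standalone function with hypothetical names (the real code runs inline against `self._sklearn_object`):

```python
# Hypothetical restatement of the fallback added in the hunks above; an empty
# return value means "no type inferred", mirroring expected_dtype == "".
from typing import List, Optional


def infer_expected_dtype(
    n_clusters: Optional[int],
    n_components: Optional[int],
    output_types: List[str],
    n_output_cols: int,
) -> str:
    # Clustering: a cluster count that disagrees with the output column
    # count means each row holds a list, i.e. an ARRAY column.
    if n_clusters is not None and n_clusters != n_output_cols:
        return "ARRAY"
    # Decomposition: same reasoning for component counts.
    if n_components is not None and n_components != n_output_cols:
        return "ARRAY"
    # Propagate a concrete type only if all outputs share it and the
    # number of inferred types matches the number of output columns.
    if output_types and all(t == output_types[0] for t in output_types) and len(output_types) == n_output_cols:
        return output_types[0]
    return ""


assert infer_expected_dtype(8, None, ["FLOAT"], 1) == "ARRAY"
assert infer_expected_dtype(None, None, ["FLOAT", "FLOAT"], 2) == "FLOAT"
```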
snowflake/ml/modeling/linear_model/lars_cv.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class LarsCV(BaseTransformer):
     r"""Cross-validated Least Angle Regression model
     For more details on this class, see [sklearn.linear_model.LarsCV]
@@ -192,7 +204,9 @@ class LarsCV(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -277,11 +291,6 @@ class LarsCV(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
         # Specify input columns so column pruning will be enforced
         selected_cols = self._get_active_columns()
         if len(selected_cols) > 0:
@@ -309,7 +318,9 @@ class LarsCV(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -580,6 +591,22 @@ class LarsCV(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -595,8 +622,8 @@ class LarsCV(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -609,13 +636,21 @@ class LarsCV(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/linear_model/lasso.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class Lasso(BaseTransformer):
     r"""Linear Model trained with L1 prior as regularizer (aka the Lasso)
     For more details on this class, see [sklearn.linear_model.Lasso]
@@ -185,7 +197,9 @@ class Lasso(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -270,11 +284,6 @@ class Lasso(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
         # Specify input columns so column pruning will be enforced
         selected_cols = self._get_active_columns()
         if len(selected_cols) > 0:
@@ -302,7 +311,9 @@ class Lasso(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -573,6 +584,22 @@ class Lasso(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -588,8 +615,8 @@ class Lasso(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -602,13 +629,21 @@ class Lasso(BaseTransformer):
         Returns:
            Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.