snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
- snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
- snowflake/ml/_internal/env_utils.py +31 -52
- snowflake/ml/_internal/file_utils.py +17 -0
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/_internal/utils/query_result_checker.py +8 -5
- snowflake/ml/_internal/utils/snowflake_env.py +95 -0
- snowflake/ml/fileset/parquet_parser.py +31 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/model_impl.py +172 -13
- snowflake/ml/model/_client/model/model_version_impl.py +96 -52
- snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
- snowflake/ml/model/_client/ops/model_ops.py +155 -9
- snowflake/ml/model/_client/sql/model.py +55 -10
- snowflake/ml/model/_client/sql/model_version.py +72 -61
- snowflake/ml/model/_client/sql/stage.py +10 -4
- snowflake/ml/model/_client/sql/tag.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/_signatures/core.py +20 -17
- snowflake/ml/model/custom_model.py +30 -27
- snowflake/ml/model/model_signature.py +16 -17
- snowflake/ml/model/type_hints.py +3 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
- snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
- snowflake/ml/modeling/_internal/model_specifications.py +3 -10
- snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
- snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
- snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
- snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
- snowflake/ml/modeling/cluster/birch.py +51 -16
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
- snowflake/ml/modeling/cluster/dbscan.py +51 -16
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
- snowflake/ml/modeling/cluster/k_means.py +51 -16
- snowflake/ml/modeling/cluster/mean_shift.py +51 -16
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
- snowflake/ml/modeling/cluster/optics.py +51 -16
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
- snowflake/ml/modeling/compose/column_transformer.py +51 -16
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
- snowflake/ml/modeling/covariance/oas.py +51 -16
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/pca.py +51 -16
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
- snowflake/ml/modeling/impute/knn_imputer.py +51 -16
- snowflake/ml/modeling/impute/missing_indicator.py +51 -16
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/lars.py +51 -16
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/perceptron.py +51 -16
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ridge.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
- snowflake/ml/modeling/manifold/isomap.py +51 -16
- snowflake/ml/modeling/manifold/mds.py +51 -16
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
- snowflake/ml/modeling/manifold/tsne.py +51 -16
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
- snowflake/ml/modeling/svm/linear_svc.py +51 -16
- snowflake/ml/modeling/svm/linear_svr.py +51 -16
- snowflake/ml/modeling/svm/nu_svc.py +51 -16
- snowflake/ml/modeling/svm/nu_svr.py +51 -16
- snowflake/ml/modeling/svm/svc.py +51 -16
- snowflake/ml/modeling/svm/svr.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
- snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
- snowflake/ml/registry/__init__.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +163 -0
- snowflake/ml/registry/model_registry.py +12 -0
- snowflake/ml/registry/registry.py +100 -90
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
- snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
- snowflake/ml/model/_client/model/model_method_info.py +0 -19
- snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
- /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
- /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
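
Most of the per-estimator churn below is autogenerated; the notable hand-written change is the new model registry entry point (snowflake/ml/registry/registry.py plus the new _manager/model_manager.py). A minimal usage sketch follows, assuming the 1.2.x Registry API surface (Registry(session=...), log_model(..., sample_input_data=...)); the connection parameters and all model/version names are illustrative, not taken from this diff.

# Hedged sketch of the new Registry entry point; names are illustrative.
import pandas as pd
from sklearn.linear_model import LinearRegression
from snowflake.ml.registry import Registry
from snowflake.snowpark import Session

connection_parameters = {}  # account, user, password, role, warehouse, ...
session = Session.builder.configs(connection_parameters).create()

train_df = pd.DataFrame({"X": [0.0, 1.0, 2.0], "Y": [0.0, 1.0, 2.0]})
model = LinearRegression().fit(train_df[["X"]], train_df["Y"])

reg = Registry(session=session)
mv = reg.log_model(
    model,
    model_name="MY_LINREG",
    version_name="V1",
    sample_input_data=train_df[["X"]],  # used to infer the model signature
)
print(mv.run(train_df[["X"]], function_name="predict"))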
snowflake/ml/modeling/linear_model/lasso_cv.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class LassoCV(BaseTransformer):
     r"""Lasso linear model with iterative fitting along a regularization path
     For more details on this class, see [sklearn.linear_model.LassoCV]
@@ -209,7 +221,9 @@ class LassoCV(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -298,11 +312,6 @@ class LassoCV(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -330,7 +339,9 @@ class LassoCV(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -601,6 +612,22 @@ class LassoCV(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"

+            # If we were unable to assign a type to this transform in the factory, infer the type here.
+            if expected_dtype == "":
+                # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+                if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                    expected_dtype = "ARRAY"
+                # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+                elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                    expected_dtype = "ARRAY"
+                else:
+                    output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                    # We can only infer the output types from the input types if the following two statemetns are true:
+                    # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                    # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                    if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                        expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -616,8 +643,8 @@ class LassoCV(BaseTransformer):

         return output_df

-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -630,13 +657,21 @@ class LassoCV(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.

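The generated estimators in this release gate fit_predict/fit_transform with scikit-learn's available_if, as in the hunks above. Below is a minimal, self-contained sketch of that gating using the public sklearn.utils.metaestimators.available_if; the Demo class and its factory are hypothetical, written only to show the mechanism.

# Hedged sketch: when the check returned by the factory is falsy for an
# instance, the decorated method is hidden (hasattr(...) is False and
# attribute access raises AttributeError). `return False and callable(...)`
# therefore disables the method unconditionally while keeping one template.
from typing import Any, Callable
from sklearn.utils.metaestimators import available_if


def _is_demo_method_enabled() -> Callable[[Any], bool]:
    def check(self: "Demo") -> bool:
        return False and callable(getattr(self, "fit_predict", None))  # always False
    return check


class Demo:
    @available_if(_is_demo_method_enabled())
    def fit_predict(self) -> str:
        return "never reached"


print(hasattr(Demo(), "fit_predict"))  # False: the method is gated off
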
snowflake/ml/modeling/linear_model/lasso_lars.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class LassoLars(BaseTransformer):
     r"""Lasso model fit with Least Angle Regression a.k.a. Lars
     For more details on this class, see [sklearn.linear_model.LassoLars]
@@ -203,7 +215,9 @@ class LassoLars(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -290,11 +304,6 @@ class LassoLars(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -322,7 +331,9 @@ class LassoLars(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -593,6 +604,22 @@ class LassoLars(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"

+            # If we were unable to assign a type to this transform in the factory, infer the type here.
+            if expected_dtype == "":
+                # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+                if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                    expected_dtype = "ARRAY"
+                # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+                elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                    expected_dtype = "ARRAY"
+                else:
+                    output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                    # We can only infer the output types from the input types if the following two statemetns are true:
+                    # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                    # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                    if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                        expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -608,8 +635,8 @@ class LassoLars(BaseTransformer):

         return output_df

-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -622,13 +649,21 @@ class LassoLars(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.

snowflake/ml/modeling/linear_model/lasso_lars_cv.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class LassoLarsCV(BaseTransformer):
     r"""Cross-validated Lasso, using the LARS algorithm
     For more details on this class, see [sklearn.linear_model.LassoLarsCV]
@@ -205,7 +217,9 @@ class LassoLarsCV(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -291,11 +305,6 @@ class LassoLarsCV(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -323,7 +332,9 @@ class LassoLarsCV(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -594,6 +605,22 @@ class LassoLarsCV(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"

+            # If we were unable to assign a type to this transform in the factory, infer the type here.
+            if expected_dtype == "":
+                # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+                if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                    expected_dtype = "ARRAY"
+                # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+                elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                    expected_dtype = "ARRAY"
+                else:
+                    output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                    # We can only infer the output types from the input types if the following two statemetns are true:
+                    # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                    # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                    if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                        expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -609,8 +636,8 @@ class LassoLarsCV(BaseTransformer):

         return output_df

-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -623,13 +650,21 @@ class LassoLarsCV(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.

snowflake/ml/modeling/linear_model/lasso_lars_ic.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class LassoLarsIC(BaseTransformer):
     r"""Lasso model fit with Lars using BIC or AIC for model selection
     For more details on this class, see [sklearn.linear_model.LassoLarsIC]
@@ -189,7 +201,9 @@ class LassoLarsIC(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -274,11 +288,6 @@ class LassoLarsIC(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -306,7 +315,9 @@ class LassoLarsIC(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -577,6 +588,22 @@ class LassoLarsIC(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"

+            # If we were unable to assign a type to this transform in the factory, infer the type here.
+            if expected_dtype == "":
+                # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+                if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                    expected_dtype = "ARRAY"
+                # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+                elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                    expected_dtype = "ARRAY"
+                else:
+                    output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                    # We can only infer the output types from the input types if the following two statemetns are true:
+                    # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                    # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                    if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                        expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -592,8 +619,8 @@ class LassoLarsIC(BaseTransformer):

         return output_df

-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -606,13 +633,21 @@ class LassoLarsIC(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.

snowflake/ml/modeling/linear_model/linear_regression.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class LinearRegression(BaseTransformer):
     r"""Ordinary least squares Linear Regression
     For more details on this class, see [sklearn.linear_model.LinearRegression]
@@ -148,7 +160,9 @@ class LinearRegression(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -227,11 +241,6 @@ class LinearRegression(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -259,7 +268,9 @@ class LinearRegression(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -530,6 +541,22 @@ class LinearRegression(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"

+            # If we were unable to assign a type to this transform in the factory, infer the type here.
+            if expected_dtype == "":
+                # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+                if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                    expected_dtype = "ARRAY"
+                # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+                elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                    expected_dtype = "ARRAY"
+                else:
+                    output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                    # We can only infer the output types from the input types if the following two statemetns are true:
+                    # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                    # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                    if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                        expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -545,8 +572,8 @@ class LinearRegression(BaseTransformer):

         return output_df

-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -559,13 +586,21 @@ class LinearRegression(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.