snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package as they appear in their respective public registries. It is provided for informational purposes only.
- snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
- snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
- snowflake/ml/_internal/env_utils.py +31 -52
- snowflake/ml/_internal/file_utils.py +17 -0
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/_internal/utils/query_result_checker.py +8 -5
- snowflake/ml/_internal/utils/snowflake_env.py +95 -0
- snowflake/ml/fileset/parquet_parser.py +31 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/model_impl.py +172 -13
- snowflake/ml/model/_client/model/model_version_impl.py +96 -52
- snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
- snowflake/ml/model/_client/ops/model_ops.py +155 -9
- snowflake/ml/model/_client/sql/model.py +55 -10
- snowflake/ml/model/_client/sql/model_version.py +72 -61
- snowflake/ml/model/_client/sql/stage.py +10 -4
- snowflake/ml/model/_client/sql/tag.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/_signatures/core.py +20 -17
- snowflake/ml/model/custom_model.py +30 -27
- snowflake/ml/model/model_signature.py +16 -17
- snowflake/ml/model/type_hints.py +3 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
- snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
- snowflake/ml/modeling/_internal/model_specifications.py +3 -10
- snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
- snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
- snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
- snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
- snowflake/ml/modeling/cluster/birch.py +51 -16
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
- snowflake/ml/modeling/cluster/dbscan.py +51 -16
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
- snowflake/ml/modeling/cluster/k_means.py +51 -16
- snowflake/ml/modeling/cluster/mean_shift.py +51 -16
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
- snowflake/ml/modeling/cluster/optics.py +51 -16
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
- snowflake/ml/modeling/compose/column_transformer.py +51 -16
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
- snowflake/ml/modeling/covariance/oas.py +51 -16
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/pca.py +51 -16
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
- snowflake/ml/modeling/impute/knn_imputer.py +51 -16
- snowflake/ml/modeling/impute/missing_indicator.py +51 -16
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/lars.py +51 -16
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/perceptron.py +51 -16
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ridge.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
- snowflake/ml/modeling/manifold/isomap.py +51 -16
- snowflake/ml/modeling/manifold/mds.py +51 -16
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
- snowflake/ml/modeling/manifold/tsne.py +51 -16
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
- snowflake/ml/modeling/svm/linear_svc.py +51 -16
- snowflake/ml/modeling/svm/linear_svr.py +51 -16
- snowflake/ml/modeling/svm/nu_svc.py +51 -16
- snowflake/ml/modeling/svm/nu_svr.py +51 -16
- snowflake/ml/modeling/svm/svc.py +51 -16
- snowflake/ml/modeling/svm/svr.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
- snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
- snowflake/ml/registry/__init__.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +163 -0
- snowflake/ml/registry/model_registry.py +12 -0
- snowflake/ml/registry/registry.py +100 -90
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
- snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
- snowflake/ml/model/_client/model/model_method_info.py +0 -19
- snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
- /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
- /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
--- a/snowflake/ml/modeling/linear_model/logistic_regression.py
+++ b/snowflake/ml/modeling/linear_model/logistic_regression.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class LogisticRegression(BaseTransformer):
     r"""Logistic Regression (aka logit, MaxEnt) classifier
     For more details on this class, see [sklearn.linear_model.LogisticRegression]
@@ -251,7 +263,9 @@ class LogisticRegression(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -341,11 +355,6 @@ class LogisticRegression(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -373,7 +382,9 @@ class LogisticRegression(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -644,6 +655,22 @@ class LogisticRegression(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -659,8 +686,8 @@ class LogisticRegression(BaseTransformer):

         return output_df

-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -673,13 +700,21 @@ class LogisticRegression(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
--- a/snowflake/ml/modeling/linear_model/logistic_regression_cv.py
+++ b/snowflake/ml/modeling/linear_model/logistic_regression_cv.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class LogisticRegressionCV(BaseTransformer):
     r"""Logistic Regression CV (aka logit, MaxEnt) classifier
     For more details on this class, see [sklearn.linear_model.LogisticRegressionCV]
@@ -270,7 +282,9 @@ class LogisticRegressionCV(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -362,11 +376,6 @@ class LogisticRegressionCV(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -394,7 +403,9 @@ class LogisticRegressionCV(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -665,6 +676,22 @@ class LogisticRegressionCV(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -680,8 +707,8 @@ class LogisticRegressionCV(BaseTransformer):

         return output_df

-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -694,13 +721,21 @@ class LogisticRegressionCV(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
--- a/snowflake/ml/modeling/linear_model/multi_task_elastic_net.py
+++ b/snowflake/ml/modeling/linear_model/multi_task_elastic_net.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class MultiTaskElasticNet(BaseTransformer):
     r"""Multi-task ElasticNet model trained with L1/L2 mixed-norm as regularizer
     For more details on this class, see [sklearn.linear_model.MultiTaskElasticNet]
@@ -176,7 +188,9 @@ class MultiTaskElasticNet(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -260,11 +274,6 @@ class MultiTaskElasticNet(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -292,7 +301,9 @@ class MultiTaskElasticNet(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -563,6 +574,22 @@ class MultiTaskElasticNet(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -578,8 +605,8 @@ class MultiTaskElasticNet(BaseTransformer):

         return output_df

-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -592,13 +619,21 @@ class MultiTaskElasticNet(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
--- a/snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py
+++ b/snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class MultiTaskElasticNetCV(BaseTransformer):
     r"""Multi-task L1/L2 ElasticNet with built-in cross-validation
     For more details on this class, see [sklearn.linear_model.MultiTaskElasticNetCV]
@@ -213,7 +225,9 @@ class MultiTaskElasticNetCV(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -301,11 +315,6 @@ class MultiTaskElasticNetCV(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -333,7 +342,9 @@ class MultiTaskElasticNetCV(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -604,6 +615,22 @@ class MultiTaskElasticNetCV(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -619,8 +646,8 @@ class MultiTaskElasticNetCV(BaseTransformer):

         return output_df

-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -633,13 +660,21 @@ class MultiTaskElasticNetCV(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
--- a/snowflake/ml/modeling/linear_model/multi_task_lasso.py
+++ b/snowflake/ml/modeling/linear_model/multi_task_lasso.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class MultiTaskLasso(BaseTransformer):
     r"""Multi-task Lasso model trained with L1/L2 mixed-norm as regularizer
     For more details on this class, see [sklearn.linear_model.MultiTaskLasso]
@@ -169,7 +181,9 @@ class MultiTaskLasso(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -252,11 +266,6 @@ class MultiTaskLasso(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -284,7 +293,9 @@ class MultiTaskLasso(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -555,6 +566,22 @@ class MultiTaskLasso(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -570,8 +597,8 @@ class MultiTaskLasso(BaseTransformer):

         return output_df

-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -584,13 +611,21 @@ class MultiTaskLasso(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.