snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
- snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
- snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
- snowflake/ml/_internal/env_utils.py +31 -52
- snowflake/ml/_internal/file_utils.py +17 -0
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/_internal/utils/query_result_checker.py +8 -5
- snowflake/ml/_internal/utils/snowflake_env.py +95 -0
- snowflake/ml/fileset/parquet_parser.py +31 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/model_impl.py +172 -13
- snowflake/ml/model/_client/model/model_version_impl.py +96 -52
- snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
- snowflake/ml/model/_client/ops/model_ops.py +155 -9
- snowflake/ml/model/_client/sql/model.py +55 -10
- snowflake/ml/model/_client/sql/model_version.py +72 -61
- snowflake/ml/model/_client/sql/stage.py +10 -4
- snowflake/ml/model/_client/sql/tag.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/_signatures/core.py +20 -17
- snowflake/ml/model/custom_model.py +30 -27
- snowflake/ml/model/model_signature.py +16 -17
- snowflake/ml/model/type_hints.py +3 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
- snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
- snowflake/ml/modeling/_internal/model_specifications.py +3 -10
- snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
- snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
- snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
- snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
- snowflake/ml/modeling/cluster/birch.py +51 -16
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
- snowflake/ml/modeling/cluster/dbscan.py +51 -16
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
- snowflake/ml/modeling/cluster/k_means.py +51 -16
- snowflake/ml/modeling/cluster/mean_shift.py +51 -16
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
- snowflake/ml/modeling/cluster/optics.py +51 -16
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
- snowflake/ml/modeling/compose/column_transformer.py +51 -16
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
- snowflake/ml/modeling/covariance/oas.py +51 -16
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/pca.py +51 -16
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
- snowflake/ml/modeling/impute/knn_imputer.py +51 -16
- snowflake/ml/modeling/impute/missing_indicator.py +51 -16
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/lars.py +51 -16
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/perceptron.py +51 -16
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ridge.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
- snowflake/ml/modeling/manifold/isomap.py +51 -16
- snowflake/ml/modeling/manifold/mds.py +51 -16
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
- snowflake/ml/modeling/manifold/tsne.py +51 -16
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
- snowflake/ml/modeling/svm/linear_svc.py +51 -16
- snowflake/ml/modeling/svm/linear_svr.py +51 -16
- snowflake/ml/modeling/svm/nu_svc.py +51 -16
- snowflake/ml/modeling/svm/nu_svr.py +51 -16
- snowflake/ml/modeling/svm/svc.py +51 -16
- snowflake/ml/modeling/svm/svr.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
- snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
- snowflake/ml/registry/__init__.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +163 -0
- snowflake/ml/registry/model_registry.py +12 -0
- snowflake/ml/registry/registry.py +100 -90
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
- snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
- snowflake/ml/model/_client/model/model_method_info.py +0 -19
- snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
- /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
- /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
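The most visible user-facing change in this release is the reworked model registry entry point (`snowflake/ml/registry/registry.py`, plus the new `_manager/model_manager.py` and `_client/sql/tag.py`). Below is a minimal sketch of the `Registry` flow in the 1.2.x line; `session`, `model`, and `input_df` are placeholders, and keyword names should be verified against the 1.2.1 documentation:

```python
# Minimal sketch of the snowflake.ml.registry.Registry flow in the 1.2.x line.
# `session`, `model`, and `input_df` are placeholders you must supply yourself.
from snowflake.ml.registry import Registry

reg = Registry(session=session, database_name="ML_DB", schema_name="PUBLIC")

# Log a fitted model as a named, versioned entity in the registry.
mv = reg.log_model(
    model,                      # e.g. a fitted sklearn/xgboost/snowflake-ml model
    model_name="my_model",
    version_name="v1",
    comment="first registered version",
)

# Retrieve it later and run inference in the warehouse.
model_ref = reg.get_model("my_model")
prediction_df = model_ref.version("v1").run(input_df, function_name="predict")
```

The hunks below show a representative sample of the autogenerated modeling wrappers (`snowflake/ml/modeling/linear_model/*`), which all received the same set of changes — hence the uniform `+51 -16` deltas in the list above.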
snowflake/ml/modeling/linear_model/ridge_cv.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class RidgeCV(BaseTransformer):
     r"""Ridge regression with built-in cross-validation
     For more details on this class, see [sklearn.linear_model.RidgeCV]
@@ -195,7 +207,9 @@ class RidgeCV(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -277,11 +291,6 @@ class RidgeCV(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -309,7 +318,9 @@ class RidgeCV(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -580,6 +591,22 @@ class RidgeCV(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -595,8 +622,8 @@ class RidgeCV(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -609,13 +636,21 @@ class RidgeCV(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
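The new module-level `_is_fit_predict_method_enabled` / `_is_fit_transform_method_enabled` helpers are feature gates: scikit-learn's `available_if` evaluates the check on every attribute access, and the leading `False and ...` keeps both methods switched off for now while leaving the wiring in place. A self-contained illustration of that decorator's behavior (a standalone sketch, not this package's code):

```python
# Standalone illustration of sklearn's `available_if` gating, mirroring the
# pattern in the generated wrappers above. Not snowflake-ml-python code.
from sklearn.utils.metaestimators import available_if


def _is_enabled(flag: bool):
    def check(self) -> bool:
        # The generated wrappers use `False and callable(...)`: the method is
        # declared but never exposed until the leading constant is flipped.
        return flag and callable(getattr(self, "payload", None))
    return check


class Gated:
    payload = staticmethod(lambda: "ran")

    @available_if(_is_enabled(True))
    def on(self):
        return self.payload()

    @available_if(_is_enabled(False))
    def off(self):
        return self.payload()


g = Gated()
print(g.on())              # "ran"
print(hasattr(g, "off"))   # False -- a failed check hides the method entirely
```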
snowflake/ml/modeling/linear_model/sgd_classifier.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class SGDClassifier(BaseTransformer):
     r"""Linear classifiers (SVM, logistic regression, etc
     For more details on this class, see [sklearn.linear_model.SGDClassifier]
@@ -300,7 +312,9 @@ class SGDClassifier(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -396,11 +410,6 @@ class SGDClassifier(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -428,7 +437,9 @@ class SGDClassifier(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -699,6 +710,22 @@ class SGDClassifier(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -714,8 +741,8 @@ class SGDClassifier(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -728,13 +755,21 @@ class SGDClassifier(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
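The new `_use_external_memory_version` / `_batch_size` attributes are wired through every wrapper into the trainer builder, but they only become meaningful for the XGBoost estimators (note the larger `+69` deltas on `snowflake/ml/modeling/xgboost/*` and the new `xgboost_external_memory_trainer.py`, +444 lines). For context, this is the documented upstream XGBoost external-memory pattern such a trainer builds on — a sketch of the public `xgboost.DataIter` API, not the package's internal trainer:

```python
# Sketch of XGBoost's documented external-memory training pattern: batches are
# streamed through a DataIter instead of materializing one giant matrix.
# Illustrates the upstream API only, not snowflake-ml-python internals.
import numpy as np
import xgboost as xgb


class BatchIter(xgb.DataIter):
    def __init__(self, batches):
        self._batches = batches          # list of (X, y) numpy batches
        self._i = 0
        super().__init__(cache_prefix="xgb-cache")  # on-disk cache location

    def next(self, input_data) -> int:
        if self._i >= len(self._batches):
            return 0                     # 0 signals end of iteration
        X, y = self._batches[self._i]
        input_data(data=X, label=y)      # hand one batch to XGBoost
        self._i += 1
        return 1                         # 1 signals more data available

    def reset(self) -> None:
        self._i = 0


rng = np.random.default_rng(0)
batches = [(rng.normal(size=(256, 4)), rng.integers(0, 2, size=256)) for _ in range(4)]
dtrain = xgb.DMatrix(BatchIter(batches))   # data is paged from the cache on demand
booster = xgb.train({"objective": "binary:logistic"}, dtrain, num_boost_round=10)
```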
snowflake/ml/modeling/linear_model/sgd_one_class_svm.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class SGDOneClassSVM(BaseTransformer):
     r"""Solves linear One-Class SVM using Stochastic Gradient Descent
     For more details on this class, see [sklearn.linear_model.SGDOneClassSVM]
@@ -207,7 +219,9 @@ class SGDOneClassSVM(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -294,11 +308,6 @@ class SGDOneClassSVM(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -326,7 +335,9 @@ class SGDOneClassSVM(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -597,6 +608,22 @@ class SGDOneClassSVM(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -612,8 +639,8 @@ class SGDOneClassSVM(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Perform fit on X and returns labels for X
         For more details on this function, see [sklearn.linear_model.SGDOneClassSVM.fit_predict]
         (https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.SGDOneClassSVM.html#sklearn.linear_model.SGDOneClassSVM.fit_predict)
@@ -628,13 +655,21 @@ class SGDOneClassSVM(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
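Unlike the regressors in this batch, `SGDOneClassSVM` genuinely supports `fit_predict` in upstream scikit-learn (via `OutlierMixin`), which is why its docstring links to the sklearn reference even though the gate is still off. For reference, the upstream behavior being wrapped:

```python
# Upstream sklearn behavior that the wrapper's fit_predict mirrors:
# fit on X, then label each row +1 (inlier) or -1 (outlier).
import numpy as np
from sklearn.linear_model import SGDOneClassSVM

rng = np.random.default_rng(0)
X = np.concatenate([rng.normal(0, 1, size=(100, 2)),   # bulk of the data
                    rng.normal(8, 1, size=(5, 2))])    # a few outliers

clf = SGDOneClassSVM(nu=0.05, random_state=0)
labels = clf.fit_predict(X)                            # array of +1 / -1
print((labels == -1).sum(), "points flagged as outliers")
```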
snowflake/ml/modeling/linear_model/sgd_regressor.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class SGDRegressor(BaseTransformer):
     r"""Linear model fitted by minimizing a regularized empirical loss with SGD
     For more details on this class, see [sklearn.linear_model.SGDRegressor]
@@ -268,7 +280,9 @@ class SGDRegressor(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -362,11 +376,6 @@ class SGDRegressor(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -394,7 +403,9 @@ class SGDRegressor(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -665,6 +676,22 @@ class SGDRegressor(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -680,8 +707,8 @@ class SGDRegressor(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -694,13 +721,21 @@ class SGDRegressor(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
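The `expected_dtype` fallback added in each file (the `@@ -665,6 +676,22 @@` hunk above and its twins) boils down to one rule: reuse the common input-derived type only when every inferred output type matches and the column counts line up; otherwise fall back to `ARRAY`. A simplified, dependency-free sketch of that decision (the real code works on Snowpark types via `_infer_signature` and `convert_sp_to_sf_type`; the strings here are stand-ins):

```python
# Dependency-free restatement of the fallback rule added in this release.
# The real implementation operates on Snowpark types; strings stand in here.
from typing import List


def infer_expected_dtype(output_types: List[str], n_output_cols: int) -> str:
    """Reuse the common input-derived type only when it is unambiguous."""
    if (
        output_types
        and all(t == output_types[0] for t in output_types)
        and len(output_types) == n_output_cols
    ):
        return output_types[0]   # e.g. "DOUBLE"
    return "ARRAY"               # mixed types or shape mismatch -> variant array


print(infer_expected_dtype(["DOUBLE", "DOUBLE"], 2))  # DOUBLE
print(infer_expected_dtype(["DOUBLE", "STRING"], 2))  # ARRAY
print(infer_expected_dtype(["DOUBLE"], 3))            # ARRAY
```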
snowflake/ml/modeling/linear_model/theil_sen_regressor.py

@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class TheilSenRegressor(BaseTransformer):
     r"""Theil-Sen Estimator: robust multivariate regression model
     For more details on this class, see [sklearn.linear_model.TheilSenRegressor]
@@ -180,7 +192,9 @@ class TheilSenRegressor(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -264,11 +278,6 @@ class TheilSenRegressor(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -296,7 +305,9 @@ class TheilSenRegressor(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
        )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -567,6 +578,22 @@ class TheilSenRegressor(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -582,8 +609,8 @@ class TheilSenRegressor(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -596,13 +623,21 @@ class TheilSenRegressor(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.