snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as published to their public registries. It is provided for informational purposes only.
- snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
- snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
- snowflake/ml/_internal/env_utils.py +31 -52
- snowflake/ml/_internal/file_utils.py +17 -0
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/_internal/utils/query_result_checker.py +8 -5
- snowflake/ml/_internal/utils/snowflake_env.py +95 -0
- snowflake/ml/fileset/parquet_parser.py +31 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/model_impl.py +172 -13
- snowflake/ml/model/_client/model/model_version_impl.py +96 -52
- snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
- snowflake/ml/model/_client/ops/model_ops.py +155 -9
- snowflake/ml/model/_client/sql/model.py +55 -10
- snowflake/ml/model/_client/sql/model_version.py +72 -61
- snowflake/ml/model/_client/sql/stage.py +10 -4
- snowflake/ml/model/_client/sql/tag.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/_signatures/core.py +20 -17
- snowflake/ml/model/custom_model.py +30 -27
- snowflake/ml/model/model_signature.py +16 -17
- snowflake/ml/model/type_hints.py +3 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
- snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
- snowflake/ml/modeling/_internal/model_specifications.py +3 -10
- snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
- snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
- snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
- snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
- snowflake/ml/modeling/cluster/birch.py +51 -16
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
- snowflake/ml/modeling/cluster/dbscan.py +51 -16
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
- snowflake/ml/modeling/cluster/k_means.py +51 -16
- snowflake/ml/modeling/cluster/mean_shift.py +51 -16
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
- snowflake/ml/modeling/cluster/optics.py +51 -16
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
- snowflake/ml/modeling/compose/column_transformer.py +51 -16
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
- snowflake/ml/modeling/covariance/oas.py +51 -16
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/pca.py +51 -16
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
- snowflake/ml/modeling/impute/knn_imputer.py +51 -16
- snowflake/ml/modeling/impute/missing_indicator.py +51 -16
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/lars.py +51 -16
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/perceptron.py +51 -16
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ridge.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
- snowflake/ml/modeling/manifold/isomap.py +51 -16
- snowflake/ml/modeling/manifold/mds.py +51 -16
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
- snowflake/ml/modeling/manifold/tsne.py +51 -16
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
- snowflake/ml/modeling/svm/linear_svc.py +51 -16
- snowflake/ml/modeling/svm/linear_svr.py +51 -16
- snowflake/ml/modeling/svm/nu_svc.py +51 -16
- snowflake/ml/modeling/svm/nu_svr.py +51 -16
- snowflake/ml/modeling/svm/svc.py +51 -16
- snowflake/ml/modeling/svm/svr.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
- snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
- snowflake/ml/registry/__init__.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +163 -0
- snowflake/ml/registry/model_registry.py +12 -0
- snowflake/ml/registry/registry.py +100 -90
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
- snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
- snowflake/ml/model/_client/model/model_method_info.py +0 -19
- snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
- /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
- /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
--- a/snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py
+++ b/snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class MultiTaskLassoCV(BaseTransformer):
     r"""Multi-task Lasso model trained with L1/L2 mixed-norm as regularizer
     For more details on this class, see [sklearn.linear_model.MultiTaskLassoCV]
@@ -200,7 +212,9 @@ class MultiTaskLassoCV(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -287,11 +301,6 @@ class MultiTaskLassoCV(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -319,7 +328,9 @@ class MultiTaskLassoCV(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -590,6 +601,22 @@ class MultiTaskLassoCV(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -605,8 +632,8 @@ class MultiTaskLassoCV(BaseTransformer):

         return output_df

-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -619,13 +646,21 @@ class MultiTaskLassoCV(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
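The same seven hunks repeat, with shifted offsets, for each autogenerated estimator below; they account for the uniform +51/-16 deltas across the modeling files listed above. The first hunk gates the new `fit_predict`/`fit_transform` methods behind scikit-learn's `available_if` descriptor. Here is a minimal, self-contained sketch of that gating pattern, assuming nothing beyond scikit-learn itself (`Toy` and `_fit_predict_enabled` are illustrative stand-ins, not package code):

```python
# Sketch of the `available_if` gating used in the hunks above.
from typing import Any, Callable

from sklearn.utils.metaestimators import available_if


def _fit_predict_enabled() -> Callable[[Any], bool]:
    def check(self: Any) -> bool:
        # Mirrors the generated check: `False and ...` short-circuits to
        # False, so the method stays hidden for this estimator family.
        return False and callable(getattr(self._impl, "fit_predict", None))
    return check


class Toy:
    def __init__(self) -> None:
        self._impl = object()

    @available_if(_fit_predict_enabled())
    def fit_predict(self, X: Any) -> Any:
        return X


print(hasattr(Toy(), "fit_predict"))  # False: available_if hides the method
```

Because the generated check returns `False and ...`, it is constant-False: the method is declared on every class but never exposed on estimators like this one.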
--- a/snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py
+++ b/snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class OrthogonalMatchingPursuit(BaseTransformer):
     r"""Orthogonal Matching Pursuit model (OMP)
     For more details on this class, see [sklearn.linear_model.OrthogonalMatchingPursuit]
@@ -155,7 +167,9 @@ class OrthogonalMatchingPursuit(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -235,11 +249,6 @@ class OrthogonalMatchingPursuit(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -267,7 +276,9 @@ class OrthogonalMatchingPursuit(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -538,6 +549,22 @@ class OrthogonalMatchingPursuit(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -553,8 +580,8 @@ class OrthogonalMatchingPursuit(BaseTransformer):

         return output_df

-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -567,13 +594,21 @@ class OrthogonalMatchingPursuit(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
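The constructor hunk hard-codes `_use_external_memory_version = False` and `_batch_size = -1` on the sklearn-derived classes and threads both into the model trainer; this matches the new `xgboost_external_memory_trainer.py` (+444 lines) and the larger +69/-16 deltas on the XGBoost wrappers in the file list. Assuming those wrappers surface the two options as constructor parameters (names taken from this diff, not from documentation), usage would look roughly like this sketch:

```python
# Hedged sketch: assumes the XGBoost wrappers expose the two new trainer
# options as constructor parameters; names come from the diff, values are
# illustrative.
from snowflake.ml.modeling.xgboost import XGBClassifier

clf = XGBClassifier(
    input_cols=["FEATURE_1", "FEATURE_2"],
    label_cols=["LABEL"],
    use_external_memory_version=True,  # assumed: route training through the external-memory trainer
    batch_size=10_000,                 # assumed: rows per training batch; -1 means unset in the diff
)
# clf.fit(training_df)  # training_df: a Snowpark DataFrame with the columns above
```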
--- a/snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py
+++ b/snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class PassiveAggressiveClassifier(BaseTransformer):
     r"""Passive Aggressive Classifier
     For more details on this class, see [sklearn.linear_model.PassiveAggressiveClassifier]
@@ -219,7 +231,9 @@ class PassiveAggressiveClassifier(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -309,11 +323,6 @@ class PassiveAggressiveClassifier(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -341,7 +350,9 @@ class PassiveAggressiveClassifier(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -612,6 +623,22 @@ class PassiveAggressiveClassifier(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -627,8 +654,8 @@ class PassiveAggressiveClassifier(BaseTransformer):

         return output_df

-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -641,13 +668,21 @@ class PassiveAggressiveClassifier(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
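The transform hunk adds a dtype fallback for transforms whose output type the factory could not determine. Distilled into a standalone function so the branching is runnable on its own (plain strings stand in for the Snowpark type objects and `convert_sp_to_sf_type` used by the real code, and the inferred input signature types are passed in as a parameter):

```python
# Standalone restatement of the new expected_dtype fallback shown above.
from typing import Any, List


def infer_expected_dtype(skl_obj: Any, output_cols: List[str], output_types: List[str]) -> str:
    expected_dtype = ""
    if hasattr(skl_obj, "n_clusters") and skl_obj.n_clusters != len(output_cols):
        # Clustering: column count differs from cluster count -> variant ARRAY.
        expected_dtype = "ARRAY"
    elif hasattr(skl_obj, "n_components") and skl_obj.n_components != len(output_cols):
        # Decomposition: column count differs from component count -> variant ARRAY.
        expected_dtype = "ARRAY"
    elif output_types and all(t == output_types[0] for t in output_types) and len(
        output_types
    ) == len(output_cols):
        # Homogeneous input types and matching column counts: keep the scalar type.
        expected_dtype = output_types[0]
    return expected_dtype


print(infer_expected_dtype(object(), ["OUT"], ["DOUBLE"]))  # DOUBLE
```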
--- a/snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py
+++ b/snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class PassiveAggressiveRegressor(BaseTransformer):
     r"""Passive Aggressive Regressor
     For more details on this class, see [sklearn.linear_model.PassiveAggressiveRegressor]
@@ -206,7 +218,9 @@ class PassiveAggressiveRegressor(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -295,11 +309,6 @@ class PassiveAggressiveRegressor(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -327,7 +336,9 @@ class PassiveAggressiveRegressor(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -598,6 +609,22 @@ class PassiveAggressiveRegressor(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -613,8 +640,8 @@ class PassiveAggressiveRegressor(BaseTransformer):

         return output_df

-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -627,13 +654,21 @@ class PassiveAggressiveRegressor(BaseTransformer):
        Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
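When the `available_if` gate does admit `fit_predict` (clustering estimators), the generated body simply fits and returns the fitted estimator's `labels_`. In plain scikit-learn terms, a sketch of the equivalent behavior:

```python
# What the generated fit_predict body amounts to for a clustering estimator:
# fit, then return the fitted object's labels_ (plain scikit-learn sketch).
import numpy as np
from sklearn.cluster import KMeans

X = np.random.RandomState(0).rand(30, 3)
km = KMeans(n_clusters=3, n_init=10, random_state=0)
km.fit(X)
print(km.labels_[:5])  # the array the wrapper's fit_predict would return
```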
--- a/snowflake/ml/modeling/linear_model/perceptron.py
+++ b/snowflake/ml/modeling/linear_model/perceptron.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("sklearn.", "").split("_")])


+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class Perceptron(BaseTransformer):
     r"""Linear perceptron classifier
     For more details on this class, see [sklearn.linear_model.Perceptron]
@@ -217,7 +229,9 @@ class Perceptron(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])

         self._deps = list(deps)
@@ -308,11 +322,6 @@ class Perceptron(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -340,7 +349,9 @@ class Perceptron(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -611,6 +622,22 @@ class Perceptron(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"

+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -626,8 +653,8 @@ class Perceptron(BaseTransformer):

         return output_df

-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.


@@ -640,13 +667,21 @@ class Perceptron(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+

     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
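Likewise, the generated `fit_transform` returns `embedding_` after fitting, an attribute only manifold learners expose. A plain scikit-learn sketch of the equivalent:

```python
# What the generated fit_transform body amounts to for a manifold learner:
# fit, then return the fitted object's embedding_ (plain scikit-learn sketch).
import numpy as np
from sklearn.manifold import TSNE

X = np.random.RandomState(0).rand(50, 8)
tsne = TSNE(n_components=2, perplexity=5.0, random_state=0)
tsne.fit(X)
print(tsne.embedding_.shape)  # (50, 2): what the wrapper's fit_transform returns
```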