snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
- snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
- snowflake/ml/_internal/env_utils.py +31 -52
- snowflake/ml/_internal/file_utils.py +17 -0
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/_internal/utils/query_result_checker.py +8 -5
- snowflake/ml/_internal/utils/snowflake_env.py +95 -0
- snowflake/ml/fileset/parquet_parser.py +31 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/model_impl.py +172 -13
- snowflake/ml/model/_client/model/model_version_impl.py +96 -52
- snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
- snowflake/ml/model/_client/ops/model_ops.py +155 -9
- snowflake/ml/model/_client/sql/model.py +55 -10
- snowflake/ml/model/_client/sql/model_version.py +72 -61
- snowflake/ml/model/_client/sql/stage.py +10 -4
- snowflake/ml/model/_client/sql/tag.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/_signatures/core.py +20 -17
- snowflake/ml/model/custom_model.py +30 -27
- snowflake/ml/model/model_signature.py +16 -17
- snowflake/ml/model/type_hints.py +3 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
- snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
- snowflake/ml/modeling/_internal/model_specifications.py +3 -10
- snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
- snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
- snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
- snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
- snowflake/ml/modeling/cluster/birch.py +51 -16
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
- snowflake/ml/modeling/cluster/dbscan.py +51 -16
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
- snowflake/ml/modeling/cluster/k_means.py +51 -16
- snowflake/ml/modeling/cluster/mean_shift.py +51 -16
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
- snowflake/ml/modeling/cluster/optics.py +51 -16
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
- snowflake/ml/modeling/compose/column_transformer.py +51 -16
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
- snowflake/ml/modeling/covariance/oas.py +51 -16
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/pca.py +51 -16
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
- snowflake/ml/modeling/impute/knn_imputer.py +51 -16
- snowflake/ml/modeling/impute/missing_indicator.py +51 -16
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/lars.py +51 -16
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/perceptron.py +51 -16
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ridge.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
- snowflake/ml/modeling/manifold/isomap.py +51 -16
- snowflake/ml/modeling/manifold/mds.py +51 -16
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
- snowflake/ml/modeling/manifold/tsne.py +51 -16
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
- snowflake/ml/modeling/svm/linear_svc.py +51 -16
- snowflake/ml/modeling/svm/linear_svr.py +51 -16
- snowflake/ml/modeling/svm/nu_svc.py +51 -16
- snowflake/ml/modeling/svm/nu_svr.py +51 -16
- snowflake/ml/modeling/svm/svc.py +51 -16
- snowflake/ml/modeling/svm/svr.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
- snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
- snowflake/ml/registry/__init__.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +163 -0
- snowflake/ml/registry/model_registry.py +12 -0
- snowflake/ml/registry/registry.py +100 -90
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
- snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
- snowflake/ml/model/_client/model/model_method_info.py +0 -19
- snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
- /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
- /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
snowflake/ml/modeling/svm/linear_svc.py
CHANGED
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.svm".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class LinearSVC(BaseTransformer):
     r"""Linear Support Vector Classification
     For more details on this class, see [sklearn.svm.LinearSVC]
@@ -214,7 +226,9 @@ class LinearSVC(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -301,11 +315,6 @@ class LinearSVC(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -333,7 +342,9 @@ class LinearSVC(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -604,6 +615,22 @@ class LinearSVC(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -619,8 +646,8 @@ class LinearSVC(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -633,13 +660,21 @@ class LinearSVC(BaseTransformer):
         Returns:
            Predicted dataset.
        """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
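The helper pair added at the top of each regenerated module gates `fit_predict`/`fit_transform` behind scikit-learn's `available_if` descriptor. Note that each inner `check` returns `False and callable(...)`, which short-circuits to `False`, so the decorated methods stay hidden on these SVM estimators even though their bodies now exist. A minimal sketch of the mechanism, assuming scikit-learn >= 1.0; `Demo` and `_wrapped` are illustrative names, not part of snowflake-ml-python:

```python
# Sketch of the gating pattern above; Demo/_wrapped are made-up names.
from typing import Any, Callable

from sklearn.utils.metaestimators import available_if


def _is_demo_method_enabled() -> Callable[[Any], bool]:
    # available_if calls this predicate with the instance on attribute access.
    def check(self: "Demo") -> bool:
        # `False and ...` short-circuits: the check always fails, so the
        # decorated method is hidden no matter what the wrapped object has.
        return False and callable(getattr(self._wrapped, "fit_predict", None))

    return check


class Demo:
    def __init__(self, wrapped: Any) -> None:
        self._wrapped = wrapped

    @available_if(_is_demo_method_enabled())
    def fit_predict(self) -> str:
        return "unreachable while check() returns False"


d = Demo(wrapped=object())
print(hasattr(d, "fit_predict"))  # False -- the descriptor raises AttributeError
```

Accessing `d.fit_predict` raises `AttributeError`, which is why `hasattr` reports `False`; dropping the `False and` from the check would re-enable the method without touching its body.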
snowflake/ml/modeling/svm/linear_svr.py
CHANGED
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.svm".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class LinearSVR(BaseTransformer):
     r"""Linear Support Vector Regression
     For more details on this class, see [sklearn.svm.LinearSVR]
@@ -188,7 +200,9 @@ class LinearSVR(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -273,11 +287,6 @@ class LinearSVR(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -305,7 +314,9 @@ class LinearSVR(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -576,6 +587,22 @@ class LinearSVR(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -591,8 +618,8 @@ class LinearSVR(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -605,13 +632,21 @@ class LinearSVR(BaseTransformer):
         Returns:
            Predicted dataset.
        """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
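Every regenerated estimator now initializes `_use_external_memory_version = False` and `_batch_size = -1` in its constructor and forwards both to the model trainer builder. Per the file list, a new `snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py` (+444 lines) presumably consumes these flags for the XGBoost estimators; that trainer's interface is not part of this diff. The sketch below only illustrates the stock XGBoost external-memory protocol (`xgboost.DataIter`, available since XGBoost 1.5) that batch-size-driven training of this kind can build on, using made-up data:

```python
# Illustrative only: XGBoost's external-memory iterator protocol, which a
# batch_size-driven trainer can build on. Not snowflake-ml-python's code.
import numpy as np
import xgboost


class BatchIterator(xgboost.DataIter):
    """Feeds fixed-size batches to XGBoost instead of one big matrix."""

    def __init__(self, X: np.ndarray, y: np.ndarray, batch_size: int) -> None:
        self._X, self._y = X, y
        self._batch_size = batch_size
        self._pos = 0
        super().__init__(cache_prefix="xgb_cache")  # spills batches to disk

    def next(self, input_data) -> int:
        if self._pos >= len(self._X):
            return 0  # signal end of iteration
        end = self._pos + self._batch_size
        input_data(data=self._X[self._pos:end], label=self._y[self._pos:end])
        self._pos = end
        return 1  # more batches remain

    def reset(self) -> None:
        self._pos = 0


rng = np.random.default_rng(0)
X, y = rng.random((1000, 8)), rng.integers(0, 2, 1000)
dmatrix = xgboost.DMatrix(BatchIterator(X, y, batch_size=256))
booster = xgboost.train({"tree_method": "hist", "objective": "binary:logistic"}, dmatrix)
```

For the SVM classes shown here the flags stay at their disabled defaults (`False` / `-1`); only the XGBoost modules in the file list grow beyond the boilerplate (+69 lines each), consistent with them actually exposing the feature.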
snowflake/ml/modeling/svm/nu_svc.py
CHANGED
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.svm".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class NuSVC(BaseTransformer):
     r"""Nu-Support Vector Classification
     For more details on this class, see [sklearn.svm.NuSVC]
@@ -217,7 +229,9 @@ class NuSVC(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -307,11 +321,6 @@ class NuSVC(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -339,7 +348,9 @@ class NuSVC(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -610,6 +621,22 @@ class NuSVC(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -625,8 +652,8 @@ class NuSVC(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -639,13 +666,21 @@ class NuSVC(BaseTransformer):
         Returns:
            Predicted dataset.
        """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
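The new `expected_dtype` fallback repeated in each estimator's transform path only commits to a scalar Snowflake type when every input column shares one Snowpark type and the input and output columns map one-to-one; clustering and decomposition transformers whose output column count disagrees with `n_clusters`/`n_components` are forced to `ARRAY`, and anything else is left undecided (variant). A standalone restatement of that ladder, with plain strings standing in for Snowpark types and a hypothetical `infer_expected_dtype` helper that is not part of the library:

```python
# Hypothetical restatement of the fallback ladder; plain strings stand in
# for Snowpark/Snowflake types and `estimator` for self._sklearn_object.
from typing import Any, List


def infer_expected_dtype(estimator: Any, output_cols: List[str], input_types: List[str]) -> str:
    # Clustering: unless there is one output column per cluster, the UDF
    # response is packed as an ARRAY.
    if hasattr(estimator, "n_clusters") and estimator.n_clusters != len(output_cols):
        return "ARRAY"
    # Decomposition: same reasoning, keyed on n_components.
    if hasattr(estimator, "n_components") and estimator.n_components != len(output_cols):
        return "ARRAY"
    # Reuse the input type only if all inputs agree on a single type and
    # columns map one-to-one; otherwise leave the type undecided ("").
    if input_types and all(t == input_types[0] for t in input_types) and len(input_types) == len(output_cols):
        return input_types[0]
    return ""


print(infer_expected_dtype(object(), ["OUT"], ["FLOAT"]))     # FLOAT
print(infer_expected_dtype(object(), ["A", "B"], ["FLOAT"]))  # "" -> stays variant
```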
snowflake/ml/modeling/svm/nu_svr.py
CHANGED
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.svm".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class NuSVR(BaseTransformer):
     r"""Nu Support Vector Regression
     For more details on this class, see [sklearn.svm.NuSVR]
@@ -182,7 +194,9 @@ class NuSVR(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -268,11 +282,6 @@ class NuSVR(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -300,7 +309,9 @@ class NuSVR(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -571,6 +582,22 @@ class NuSVR(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -586,8 +613,8 @@ class NuSVR(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -600,13 +627,21 @@ class NuSVR(BaseTransformer):
         Returns:
            Predicted dataset.
        """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
snowflake/ml/modeling/svm/svc.py
CHANGED
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.svm".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class SVC(BaseTransformer):
     r"""C-Support Vector Classification
     For more details on this class, see [sklearn.svm.SVC]
@@ -220,7 +232,9 @@ class SVC(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -310,11 +324,6 @@ class SVC(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -342,7 +351,9 @@ class SVC(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -613,6 +624,22 @@ class SVC(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -628,8 +655,8 @@ class SVC(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled()) # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -642,13 +669,21 @@ class SVC(BaseTransformer):
         Returns:
            Predicted dataset.
        """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.