snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
- snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
- snowflake/ml/_internal/env_utils.py +31 -52
- snowflake/ml/_internal/file_utils.py +17 -0
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/_internal/utils/query_result_checker.py +8 -5
- snowflake/ml/_internal/utils/snowflake_env.py +95 -0
- snowflake/ml/fileset/parquet_parser.py +31 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/model_impl.py +172 -13
- snowflake/ml/model/_client/model/model_version_impl.py +96 -52
- snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
- snowflake/ml/model/_client/ops/model_ops.py +155 -9
- snowflake/ml/model/_client/sql/model.py +55 -10
- snowflake/ml/model/_client/sql/model_version.py +72 -61
- snowflake/ml/model/_client/sql/stage.py +10 -4
- snowflake/ml/model/_client/sql/tag.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/_signatures/core.py +20 -17
- snowflake/ml/model/custom_model.py +30 -27
- snowflake/ml/model/model_signature.py +16 -17
- snowflake/ml/model/type_hints.py +3 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
- snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
- snowflake/ml/modeling/_internal/model_specifications.py +3 -10
- snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
- snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
- snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
- snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
- snowflake/ml/modeling/cluster/birch.py +51 -16
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
- snowflake/ml/modeling/cluster/dbscan.py +51 -16
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
- snowflake/ml/modeling/cluster/k_means.py +51 -16
- snowflake/ml/modeling/cluster/mean_shift.py +51 -16
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
- snowflake/ml/modeling/cluster/optics.py +51 -16
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
- snowflake/ml/modeling/compose/column_transformer.py +51 -16
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
- snowflake/ml/modeling/covariance/oas.py +51 -16
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/pca.py +51 -16
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
- snowflake/ml/modeling/impute/knn_imputer.py +51 -16
- snowflake/ml/modeling/impute/missing_indicator.py +51 -16
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/lars.py +51 -16
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/perceptron.py +51 -16
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ridge.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
- snowflake/ml/modeling/manifold/isomap.py +51 -16
- snowflake/ml/modeling/manifold/mds.py +51 -16
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
- snowflake/ml/modeling/manifold/tsne.py +51 -16
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
- snowflake/ml/modeling/svm/linear_svc.py +51 -16
- snowflake/ml/modeling/svm/linear_svr.py +51 -16
- snowflake/ml/modeling/svm/nu_svc.py +51 -16
- snowflake/ml/modeling/svm/nu_svr.py +51 -16
- snowflake/ml/modeling/svm/svc.py +51 -16
- snowflake/ml/modeling/svm/svr.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
- snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
- snowflake/ml/registry/__init__.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +163 -0
- snowflake/ml/registry/model_registry.py +12 -0
- snowflake/ml/registry/registry.py +100 -90
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
- snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
- snowflake/ml/model/_client/model/model_method_info.py +0 -19
- snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
- /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
- /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
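
The headline change in this range is the reworked model registry: `snowflake/ml/registry/registry.py` was largely rewritten, a new `snowflake/ml/registry/_manager/model_manager.py` was added, and tag support arrived via `snowflake/ml/model/_client/sql/tag.py`. As a hedged illustration of the 1.2.x-style registry workflow, here is a minimal sketch; the Snowpark `session`, the sklearn demo model, and the exact keyword names are assumptions based on the public 1.2.x API, not contents of this diff.

```python
# Hedged sketch: logging and invoking a model with the reworked registry.
# Assumes an already-authenticated snowflake.snowpark.Session named `session`;
# model and version names are placeholders.
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

from snowflake.ml.registry import Registry

X, y = load_iris(return_X_y=True)
clf = LogisticRegression(max_iter=200).fit(X, y)

reg = Registry(session=session)
mv = reg.log_model(
    clf,
    model_name="IRIS_CLF",        # placeholder name
    version_name="V1",            # placeholder version
    sample_input_data=X[:10],     # used to infer the model signature
)
print(mv.run(X[:5], function_name="predict"))
```

The hunks below show the one change this release applies uniformly to the autogenerated `naive_bayes` wrappers; the same +51/-16 pattern repeats across the other `snowflake/ml/modeling` files listed above.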
snowflake/ml/modeling/naive_bayes/bernoulli_nb.py

```diff
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.naive_bayes".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class BernoulliNB(BaseTransformer):
     r"""Naive Bayes classifier for multivariate Bernoulli models
     For more details on this class, see [sklearn.naive_bayes.BernoulliNB]
@@ -150,7 +162,9 @@ class BernoulliNB(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -230,11 +244,6 @@ class BernoulliNB(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -262,7 +271,9 @@ class BernoulliNB(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -533,6 +544,22 @@ class BernoulliNB(BaseTransformer):
         # each row containing a list of values.
         expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -548,8 +575,8 @@ class BernoulliNB(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -562,13 +589,21 @@ class BernoulliNB(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
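
The new `_is_fit_predict_method_enabled`/`_is_fit_transform_method_enabled` factories return a check for scikit-learn's `available_if` descriptor, which hides the decorated method (even from `hasattr`) whenever the check returns falsy; since the generated check is hard-coded to `False and ...`, these naive Bayes wrappers never expose `fit_predict` or `fit_transform`. A self-contained sketch of the pattern using only public scikit-learn APIs, with the factory indirection dropped for brevity (`Wrapper`, `has_fit_predict`, and `_inner` are illustrative names, not snowflake-ml-python code):

```python
from sklearn.cluster import KMeans
from sklearn.utils.metaestimators import available_if


def has_fit_predict(self) -> bool:
    # Mirrors the generated check: `False and ...` always fails, so the
    # decorated method is disabled regardless of the wrapped estimator.
    return False and callable(getattr(self._inner, "fit_predict", None))


class Wrapper:
    def __init__(self, inner):
        self._inner = inner

    @available_if(has_fit_predict)
    def fit_predict(self, X):
        return self._inner.fit_predict(X)


w = Wrapper(KMeans(n_clusters=2))
print(hasattr(w, "fit_predict"))  # False: available_if raises AttributeError
```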
snowflake/ml/modeling/naive_bayes/categorical_nb.py

```diff
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.naive_bayes".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class CategoricalNB(BaseTransformer):
     r"""Naive Bayes classifier for categorical features
     For more details on this class, see [sklearn.naive_bayes.CategoricalNB]
@@ -156,7 +168,9 @@ class CategoricalNB(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -236,11 +250,6 @@ class CategoricalNB(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -268,7 +277,9 @@ class CategoricalNB(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -539,6 +550,22 @@ class CategoricalNB(BaseTransformer):
         # each row containing a list of values.
         expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -554,8 +581,8 @@ class CategoricalNB(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -568,13 +595,21 @@ class CategoricalNB(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
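
These hunks also add an output-type fallback for transforms the factory could not type: clustering or decomposition estimators whose output-column count does not match `n_clusters`/`n_components` force `ARRAY`, and otherwise a concrete Snowflake type is chosen only when every inferred output type agrees and the counts line up. A simplified, runnable sketch of that last rule, with plain strings standing in for Snowpark types (`infer_expected_dtype` is an illustrative name, not package code):

```python
def infer_expected_dtype(output_types, n_output_cols, expected_dtype=""):
    # Mirrors the final `else` branch above: commit to one concrete type only
    # when every inferred output type agrees and there is one output column
    # per type; the clustering/decomposition "ARRAY" branches are elided.
    if expected_dtype == "" and output_types:
        if all(t == output_types[0] for t in output_types) and len(output_types) == n_output_cols:
            expected_dtype = output_types[0]
    return expected_dtype


print(infer_expected_dtype(["DOUBLE", "DOUBLE"], 2))  # -> DOUBLE
print(infer_expected_dtype(["DOUBLE", "STRING"], 2))  # -> "" (unresolved; variant handling applies)
```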
snowflake/ml/modeling/naive_bayes/complement_nb.py

```diff
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.naive_bayes".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class ComplementNB(BaseTransformer):
     r"""The Complement Naive Bayes classifier described in Rennie et al
     For more details on this class, see [sklearn.naive_bayes.ComplementNB]
@@ -150,7 +162,9 @@ class ComplementNB(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -230,11 +244,6 @@ class ComplementNB(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -262,7 +271,9 @@ class ComplementNB(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -533,6 +544,22 @@ class ComplementNB(BaseTransformer):
         # each row containing a list of values.
         expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -548,8 +575,8 @@ class ComplementNB(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -562,13 +589,21 @@ class ComplementNB(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
snowflake/ml/modeling/naive_bayes/gaussian_nb.py

```diff
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.naive_bayes".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class GaussianNB(BaseTransformer):
     r"""Gaussian Naive Bayes (GaussianNB)
     For more details on this class, see [sklearn.naive_bayes.GaussianNB]
@@ -134,7 +146,9 @@ class GaussianNB(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -211,11 +225,6 @@ class GaussianNB(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -243,7 +252,9 @@ class GaussianNB(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -514,6 +525,22 @@ class GaussianNB(BaseTransformer):
         # each row containing a list of values.
         expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -529,8 +556,8 @@ class GaussianNB(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -543,13 +570,21 @@ class GaussianNB(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
snowflake/ml/modeling/naive_bayes/multinomial_nb.py

```diff
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.naive_bayes".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class MultinomialNB(BaseTransformer):
     r"""Naive Bayes classifier for multinomial models
     For more details on this class, see [sklearn.naive_bayes.MultinomialNB]
@@ -145,7 +157,9 @@ class MultinomialNB(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -224,11 +238,6 @@ class MultinomialNB(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -256,7 +265,9 @@ class MultinomialNB(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -527,6 +538,22 @@ class MultinomialNB(BaseTransformer):
         # each row containing a list of values.
         expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -542,8 +569,8 @@ class MultinomialNB(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -556,13 +583,21 @@ class MultinomialNB(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
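
Each constructor in these hunks now sets `_use_external_memory_version = False` and `_batch_size = -1` and threads both through to the trainer builder; the corresponding machinery lands in the new `snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py` (+444 lines) and the widened `model_trainer_builder.py` in the file list above. For orientation, here is a rough sketch of what batched, external-memory XGBoost training looks like, built only on public XGBoost APIs (the iterator class, synthetic data, and cache path are assumptions; this is not the snowflake-ml-python trainer itself):

```python
import numpy as np
import xgboost as xgb


class BatchIter(xgb.DataIter):
    """Feeds fixed-size batches so the full dataset never sits in memory."""

    def __init__(self, batches):
        self._batches = batches
        self._it = 0
        super().__init__(cache_prefix="./xgb_cache")  # spill pages to disk

    def next(self, input_data):
        if self._it == len(self._batches):
            return 0  # no more batches: end of one pass over the data
        X, y = self._batches[self._it]
        input_data(data=X, label=y)
        self._it += 1
        return 1

    def reset(self):
        self._it = 0  # called by XGBoost before each new pass


rng = np.random.default_rng(0)
batches = [(rng.normal(size=(256, 4)), rng.integers(0, 2, 256)) for _ in range(4)]
dtrain = xgb.DMatrix(BatchIter(batches))  # external-memory DMatrix
booster = xgb.train({"objective": "binary:logistic"}, dtrain, num_boost_round=5)
print(booster.num_boosted_rounds())
```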