snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
- snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
- snowflake/ml/_internal/env_utils.py +31 -52
- snowflake/ml/_internal/file_utils.py +17 -0
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/_internal/utils/query_result_checker.py +8 -5
- snowflake/ml/_internal/utils/snowflake_env.py +95 -0
- snowflake/ml/fileset/parquet_parser.py +31 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/model_impl.py +172 -13
- snowflake/ml/model/_client/model/model_version_impl.py +96 -52
- snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
- snowflake/ml/model/_client/ops/model_ops.py +155 -9
- snowflake/ml/model/_client/sql/model.py +55 -10
- snowflake/ml/model/_client/sql/model_version.py +72 -61
- snowflake/ml/model/_client/sql/stage.py +10 -4
- snowflake/ml/model/_client/sql/tag.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/_signatures/core.py +20 -17
- snowflake/ml/model/custom_model.py +30 -27
- snowflake/ml/model/model_signature.py +16 -17
- snowflake/ml/model/type_hints.py +3 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
- snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
- snowflake/ml/modeling/_internal/model_specifications.py +3 -10
- snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
- snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
- snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
- snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
- snowflake/ml/modeling/cluster/birch.py +51 -16
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
- snowflake/ml/modeling/cluster/dbscan.py +51 -16
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
- snowflake/ml/modeling/cluster/k_means.py +51 -16
- snowflake/ml/modeling/cluster/mean_shift.py +51 -16
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
- snowflake/ml/modeling/cluster/optics.py +51 -16
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
- snowflake/ml/modeling/compose/column_transformer.py +51 -16
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
- snowflake/ml/modeling/covariance/oas.py +51 -16
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/pca.py +51 -16
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
- snowflake/ml/modeling/impute/knn_imputer.py +51 -16
- snowflake/ml/modeling/impute/missing_indicator.py +51 -16
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/lars.py +51 -16
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/perceptron.py +51 -16
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ridge.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
- snowflake/ml/modeling/manifold/isomap.py +51 -16
- snowflake/ml/modeling/manifold/mds.py +51 -16
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
- snowflake/ml/modeling/manifold/tsne.py +51 -16
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
- snowflake/ml/modeling/svm/linear_svc.py +51 -16
- snowflake/ml/modeling/svm/linear_svr.py +51 -16
- snowflake/ml/modeling/svm/nu_svc.py +51 -16
- snowflake/ml/modeling/svm/nu_svr.py +51 -16
- snowflake/ml/modeling/svm/svc.py +51 -16
- snowflake/ml/modeling/svm/svr.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
- snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
- snowflake/ml/registry/__init__.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +163 -0
- snowflake/ml/registry/model_registry.py +12 -0
- snowflake/ml/registry/registry.py +100 -90
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
- snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
- snowflake/ml/model/_client/model/model_method_info.py +0 -19
- snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
- /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
- /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
snowflake/ml/modeling/metrics/ranking.py

```diff
@@ -122,7 +122,8 @@ def precision_recall_curve(
         result_module = cloudpickle.loads(pickled_result_module)
         return result_module.serialize(session, (precision, recall, thresholds))  # type: ignore[no-any-return]
 
-
+    kwargs = telemetry.get_sproc_statement_params_kwargs(precision_recall_curve_anon_sproc, statement_params)
+    result_object = result.deserialize(session, precision_recall_curve_anon_sproc(session, **kwargs))
     res: Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]] = result_object
     return res
 
@@ -271,7 +272,8 @@ def roc_auc_score(
         result_module = cloudpickle.loads(pickled_result_module)
         return result_module.serialize(session, auc)  # type: ignore[no-any-return]
 
-
+    kwargs = telemetry.get_sproc_statement_params_kwargs(roc_auc_score_anon_sproc, statement_params)
+    result_object = result.deserialize(session, roc_auc_score_anon_sproc(session, **kwargs))
     auc: Union[float, npt.NDArray[np.float_]] = result_object
     return auc
 
@@ -372,7 +374,9 @@ def roc_curve(
         result_module = cloudpickle.loads(pickled_result_module)
         return result_module.serialize(session, (fpr, tpr, thresholds))  # type: ignore[no-any-return]
 
-
+    kwargs = telemetry.get_sproc_statement_params_kwargs(roc_curve_anon_sproc, statement_params)
+    result_object = result.deserialize(session, roc_curve_anon_sproc(session, **kwargs))
+
     res: Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]] = result_object
 
     return res
```
snowflake/ml/modeling/metrics/regression.py

```diff
@@ -108,7 +108,8 @@ def d2_absolute_error_score(
         result_module = cloudpickle.loads(pickled_snowflake_result)
         return result_module.serialize(session, score)  # type: ignore[no-any-return]
 
-
+    kwargs = telemetry.get_sproc_statement_params_kwargs(d2_absolute_error_score_anon_sproc, statement_params)
+    result_object = result.deserialize(session, d2_absolute_error_score_anon_sproc(session, **kwargs))
     score: Union[float, npt.NDArray[np.float_]] = result_object
     return score
 
@@ -205,7 +206,8 @@ def d2_pinball_score(
         result_module = cloudpickle.loads(pickled_result_module)
         return result_module.serialize(session, score)  # type: ignore[no-any-return]
 
-
+    kwargs = telemetry.get_sproc_statement_params_kwargs(d2_pinball_score_anon_sproc, statement_params)
+    result_object = result.deserialize(session, d2_pinball_score_anon_sproc(session, **kwargs))
 
     score: Union[float, npt.NDArray[np.float_]] = result_object
     return score
@@ -319,7 +321,8 @@ def explained_variance_score(
         result_module = cloudpickle.loads(pickled_result_module)
         return result_module.serialize(session, score)  # type: ignore[no-any-return]
 
-
+    kwargs = telemetry.get_sproc_statement_params_kwargs(explained_variance_score_anon_sproc, statement_params)
+    result_object = result.deserialize(session, explained_variance_score_anon_sproc(session, **kwargs))
     score: Union[float, npt.NDArray[np.float_]] = result_object
     return score
 
```
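Both metrics files make the same change: instead of invoking the anonymous stored procedure directly, they first build keyword arguments through `telemetry.get_sproc_statement_params_kwargs`, so `statement_params` is only forwarded when the sproc accepts it. A minimal sketch of that call-shaping pattern; the helper body and the sproc stand-in below are illustrative assumptions, not the library's implementation:

```python
import inspect
from typing import Any, Callable, Dict, Optional


def get_sproc_statement_params_kwargs(
    sproc: Callable[..., Any], statement_params: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
    # Hypothetical body: forward statement_params only if the callable
    # declares that keyword; otherwise return an empty kwargs dict.
    if statement_params and "statement_params" in inspect.signature(sproc).parameters:
        return {"statement_params": statement_params}
    return {}


def roc_curve_anon_sproc(session: Any, statement_params: Optional[Dict[str, Any]] = None) -> bytes:
    # Stand-in for the anonymous Snowpark sproc that roc_curve() registers.
    return b""


kwargs = get_sproc_statement_params_kwargs(roc_curve_anon_sproc, {"QUERY_TAG": "metrics"})
print(kwargs)  # {'statement_params': {'QUERY_TAG': 'metrics'}}
```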
snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py

```diff
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.mixture".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class BayesianGaussianMixture(BaseTransformer):
     r"""Variational Bayesian estimation of a Gaussian mixture
     For more details on this class, see [sklearn.mixture.BayesianGaussianMixture]
@@ -241,7 +253,9 @@ class BayesianGaussianMixture(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -333,11 +347,6 @@ class BayesianGaussianMixture(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -365,7 +374,9 @@ class BayesianGaussianMixture(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -636,6 +647,22 @@ class BayesianGaussianMixture(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -651,8 +678,8 @@ class BayesianGaussianMixture(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Estimate model parameters using X and predict the labels for X
         For more details on this function, see [sklearn.mixture.BayesianGaussianMixture.fit_predict]
         (https://scikit-learn.org/stable/modules/generated/sklearn.mixture.BayesianGaussianMixture.html#sklearn.mixture.BayesianGaussianMixture.fit_predict)
@@ -667,13 +694,21 @@ class BayesianGaussianMixture(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
snowflake/ml/modeling/mixture/gaussian_mixture.py

```diff
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.mixture".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class GaussianMixture(BaseTransformer):
     r"""Gaussian Mixture
     For more details on this class, see [sklearn.mixture.GaussianMixture]
@@ -217,7 +229,9 @@ class GaussianMixture(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -306,11 +320,6 @@ class GaussianMixture(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -338,7 +347,9 @@ class GaussianMixture(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -609,6 +620,22 @@ class GaussianMixture(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -624,8 +651,8 @@ class GaussianMixture(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Estimate model parameters using X and predict the labels for X
         For more details on this function, see [sklearn.mixture.GaussianMixture.fit_predict]
         (https://scikit-learn.org/stable/modules/generated/sklearn.mixture.GaussianMixture.html#sklearn.mixture.GaussianMixture.fit_predict)
@@ -640,13 +667,21 @@ class GaussianMixture(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
snowflake/ml/modeling/multiclass/one_vs_one_classifier.py

```diff
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.multiclass".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class OneVsOneClassifier(BaseTransformer):
     r"""One-vs-one multiclass strategy
     For more details on this class, see [sklearn.multiclass.OneVsOneClassifier]
@@ -141,7 +153,9 @@ class OneVsOneClassifier(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
         deps = deps | gather_dependencies(estimator)
         self._deps = list(deps)
@@ -218,11 +232,6 @@ class OneVsOneClassifier(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -250,7 +259,9 @@ class OneVsOneClassifier(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -521,6 +532,22 @@ class OneVsOneClassifier(BaseTransformer):
             # each row containing a list of values.
            expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -536,8 +563,8 @@ class OneVsOneClassifier(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -550,13 +577,21 @@ class OneVsOneClassifier(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
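The sixteen-line block added to every wrapper's `transform()` only fires when the factory left `expected_dtype` empty: clustering and decomposition estimators whose output-column count disagrees with `n_clusters`/`n_components` fall back to `ARRAY`, and otherwise the scalar type is taken from the input signature when it is homogeneous and the column counts match. That decision, reduced to plain Python (string type names stand in for Snowpark `DataType`s):

```python
from typing import List


def infer_expected_dtype(output_types: List[str], n_output_cols: int) -> str:
    # Homogeneous types and matching column counts: keep the scalar type.
    if output_types and all(t == output_types[0] for t in output_types) and len(output_types) == n_output_cols:
        return output_types[0]
    # Otherwise the transform result is emitted as a variant ARRAY column.
    return "ARRAY"


print(infer_expected_dtype(["DOUBLE", "DOUBLE"], 2))   # DOUBLE
print(infer_expected_dtype(["DOUBLE", "VARCHAR"], 2))  # ARRAY
```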
snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py

```diff
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.multiclass".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class OneVsRestClassifier(BaseTransformer):
     r"""One-vs-the-rest (OvR) multiclass strategy
     For more details on this class, see [sklearn.multiclass.OneVsRestClassifier]
@@ -149,7 +161,9 @@ class OneVsRestClassifier(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
         deps = deps | gather_dependencies(estimator)
         self._deps = list(deps)
@@ -227,11 +241,6 @@ class OneVsRestClassifier(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -259,7 +268,9 @@ class OneVsRestClassifier(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -530,6 +541,22 @@ class OneVsRestClassifier(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -545,8 +572,8 @@ class OneVsRestClassifier(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -559,13 +586,21 @@ class OneVsRestClassifier(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```
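The previously empty `fit_predict` and `fit_transform` bodies are now filled in with the same three steps everywhere: fit, assert the wrapped estimator exists, then read its `labels_` or `embedding_` attribute. In plain scikit-learn terms the pattern looks like this (the Snowpark wrapper plumbing is elided; this only shows where the return value comes from):

```python
import numpy as np
from sklearn.cluster import KMeans

X = np.array([[0.0], [0.1], [10.0], [10.1]])
est = KMeans(n_clusters=2, n_init=10, random_state=0).fit(X)
print(est.labels_)  # what a gated fit_predict returns after self.fit(dataset)
```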
snowflake/ml/modeling/multiclass/output_code_classifier.py

```diff
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.multiclass".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class OutputCodeClassifier(BaseTransformer):
     r"""(Error-Correcting) Output-Code multiclass strategy
     For more details on this class, see [sklearn.multiclass.OutputCodeClassifier]
@@ -151,7 +163,9 @@ class OutputCodeClassifier(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
         deps = deps | gather_dependencies(estimator)
         self._deps = list(deps)
@@ -230,11 +244,6 @@ class OutputCodeClassifier(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -262,7 +271,9 @@ class OutputCodeClassifier(BaseTransformer):
             label_cols=self.label_cols,
             sample_weight_col=self.sample_weight_col,
             autogenerated=self._autogenerated,
-            subproject=_SUBPROJECT
+            subproject=_SUBPROJECT,
+            use_external_memory_version=self._use_external_memory_version,
+            batch_size=self._batch_size,
         )
         self._sklearn_object = model_trainer.train()
         self._is_fitted = True
@@ -533,6 +544,22 @@ class OutputCodeClassifier(BaseTransformer):
             # each row containing a list of values.
             expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statemetns are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -548,8 +575,8 @@ class OutputCodeClassifier(BaseTransformer):
 
         return output_df
 
-    @available_if(
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -562,13 +589,21 @@ class OutputCodeClassifier(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-
-
-
-
-
-
-
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
```