snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
|
|
33
33
|
BatchInferenceKwargsTypedDict,
|
34
34
|
ScoreKwargsTypedDict
|
35
35
|
)
|
36
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
37
|
+
from snowflake.ml.model.model_signature import (
|
38
|
+
BaseFeatureSpec,
|
39
|
+
DataType,
|
40
|
+
FeatureSpec,
|
41
|
+
ModelSignature,
|
42
|
+
_infer_signature,
|
43
|
+
_rename_signature_with_snowflake_identifiers,
|
44
|
+
)
|
36
45
|
|
37
46
|
from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
|
38
47
|
|
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
|
|
43
52
|
validate_sklearn_args,
|
44
53
|
)
|
45
54
|
|
46
|
-
from snowflake.ml.model.model_signature import (
|
47
|
-
DataType,
|
48
|
-
FeatureSpec,
|
49
|
-
ModelSignature,
|
50
|
-
_infer_signature,
|
51
|
-
_rename_signature_with_snowflake_identifiers,
|
52
|
-
BaseFeatureSpec,
|
53
|
-
)
|
54
|
-
from snowflake.ml.model._signatures import utils as model_signature_utils
|
55
|
-
|
56
55
|
_PROJECT = "ModelDevelopment"
|
57
56
|
# Derive subproject from module name by removing "sklearn"
|
58
57
|
# and converting module name from underscore to CamelCase
|
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neighbors".replace("skle
|
|
61
60
|
|
62
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
63
62
|
|
64
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
65
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
66
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
67
|
-
return check
|
68
|
-
|
69
|
-
|
70
63
|
class NearestCentroid(BaseTransformer):
|
71
64
|
r"""Nearest centroid classifier
|
72
65
|
For more details on this class, see [sklearn.neighbors.NearestCentroid]
|
@@ -213,12 +206,7 @@ class NearestCentroid(BaseTransformer):
|
|
213
206
|
)
|
214
207
|
return selected_cols
|
215
208
|
|
216
|
-
|
217
|
-
project=_PROJECT,
|
218
|
-
subproject=_SUBPROJECT,
|
219
|
-
custom_tags=dict([("autogen", True)]),
|
220
|
-
)
|
221
|
-
def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "NearestCentroid":
|
209
|
+
def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "NearestCentroid":
|
222
210
|
"""Fit the NearestCentroid model according to the given training data
|
223
211
|
For more details on this function, see [sklearn.neighbors.NearestCentroid.fit]
|
224
212
|
(https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.NearestCentroid.html#sklearn.neighbors.NearestCentroid.fit)
|
@@ -245,12 +233,14 @@ class NearestCentroid(BaseTransformer):
|
|
245
233
|
|
246
234
|
self._snowpark_cols = dataset.select(self.input_cols).columns
|
247
235
|
|
248
|
-
|
236
|
+
# If we are already in a stored procedure, no need to kick off another one.
|
249
237
|
if SNOWML_SPROC_ENV in os.environ:
|
250
238
|
statement_params = telemetry.get_function_usage_statement_params(
|
251
239
|
project=_PROJECT,
|
252
240
|
subproject=_SUBPROJECT,
|
253
|
-
function_name=telemetry.get_statement_params_full_func_name(
|
241
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
242
|
+
inspect.currentframe(), NearestCentroid.__class__.__name__
|
243
|
+
),
|
254
244
|
api_calls=[Session.call],
|
255
245
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
256
246
|
)
|
@@ -271,27 +261,24 @@ class NearestCentroid(BaseTransformer):
|
|
271
261
|
)
|
272
262
|
self._sklearn_object = model_trainer.train()
|
273
263
|
self._is_fitted = True
|
274
|
-
self.
|
264
|
+
self._generate_model_signatures(dataset)
|
275
265
|
return self
|
276
266
|
|
277
267
|
def _batch_inference_validate_snowpark(
|
278
268
|
self,
|
279
269
|
dataset: DataFrame,
|
280
270
|
inference_method: str,
|
281
|
-
) ->
|
282
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
283
|
-
return the available package that exists in the snowflake anaconda channel
|
271
|
+
) -> None:
|
272
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
284
273
|
|
285
274
|
Args:
|
286
275
|
dataset: snowpark dataframe
|
287
276
|
inference_method: the inference method such as predict, score...
|
288
|
-
|
277
|
+
|
289
278
|
Raises:
|
290
279
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
291
280
|
SnowflakeMLException: If the session is None, raise error
|
292
281
|
|
293
|
-
Returns:
|
294
|
-
A list of available package that exists in the snowflake anaconda channel
|
295
282
|
"""
|
296
283
|
if not self._is_fitted:
|
297
284
|
raise exceptions.SnowflakeMLException(
|
@@ -309,9 +296,7 @@ class NearestCentroid(BaseTransformer):
|
|
309
296
|
"Session must not specified for snowpark dataset."
|
310
297
|
),
|
311
298
|
)
|
312
|
-
|
313
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
314
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
299
|
+
|
315
300
|
|
316
301
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
317
302
|
@telemetry.send_api_usage_telemetry(
|
@@ -347,7 +332,9 @@ class NearestCentroid(BaseTransformer):
|
|
347
332
|
# when it is classifier, infer the datatype from label columns
|
348
333
|
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
349
334
|
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
350
|
-
label_cols_signatures = [
|
335
|
+
label_cols_signatures = [
|
336
|
+
row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
|
337
|
+
]
|
351
338
|
if len(label_cols_signatures) == 0:
|
352
339
|
error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
|
353
340
|
raise exceptions.SnowflakeMLException(
|
@@ -355,25 +342,23 @@ class NearestCentroid(BaseTransformer):
|
|
355
342
|
original_exception=ValueError(error_str),
|
356
343
|
)
|
357
344
|
|
358
|
-
expected_type_inferred = convert_sp_to_sf_type(
|
359
|
-
label_cols_signatures[0].as_snowpark_type()
|
360
|
-
)
|
345
|
+
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
361
346
|
|
362
|
-
self.
|
363
|
-
|
347
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
348
|
+
self._deps = self._get_dependencies()
|
349
|
+
assert isinstance(
|
350
|
+
dataset._session, Session
|
351
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
364
352
|
|
365
353
|
transform_kwargs = dict(
|
366
|
-
session
|
367
|
-
dependencies
|
368
|
-
drop_input_cols
|
369
|
-
expected_output_cols_type
|
354
|
+
session=dataset._session,
|
355
|
+
dependencies=self._deps,
|
356
|
+
drop_input_cols=self._drop_input_cols,
|
357
|
+
expected_output_cols_type=expected_type_inferred,
|
370
358
|
)
|
371
359
|
|
372
360
|
elif isinstance(dataset, pd.DataFrame):
|
373
|
-
transform_kwargs = dict(
|
374
|
-
snowpark_input_cols = self._snowpark_cols,
|
375
|
-
drop_input_cols = self._drop_input_cols
|
376
|
-
)
|
361
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
377
362
|
|
378
363
|
transform_handlers = ModelTransformerBuilder.build(
|
379
364
|
dataset=dataset,
|
@@ -413,7 +398,7 @@ class NearestCentroid(BaseTransformer):
|
|
413
398
|
Transformed dataset.
|
414
399
|
"""
|
415
400
|
super()._check_dataset_type(dataset)
|
416
|
-
inference_method="transform"
|
401
|
+
inference_method = "transform"
|
417
402
|
|
418
403
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
419
404
|
# are specific to the type of dataset used.
|
@@ -443,24 +428,19 @@ class NearestCentroid(BaseTransformer):
|
|
443
428
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
444
429
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
445
430
|
|
446
|
-
self.
|
447
|
-
|
448
|
-
inference_method=inference_method,
|
449
|
-
)
|
431
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
432
|
+
self._deps = self._get_dependencies()
|
450
433
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
451
434
|
|
452
435
|
transform_kwargs = dict(
|
453
|
-
session
|
454
|
-
dependencies
|
455
|
-
drop_input_cols
|
456
|
-
expected_output_cols_type
|
436
|
+
session=dataset._session,
|
437
|
+
dependencies=self._deps,
|
438
|
+
drop_input_cols=self._drop_input_cols,
|
439
|
+
expected_output_cols_type=expected_dtype,
|
457
440
|
)
|
458
441
|
|
459
442
|
elif isinstance(dataset, pd.DataFrame):
|
460
|
-
transform_kwargs = dict(
|
461
|
-
snowpark_input_cols = self._snowpark_cols,
|
462
|
-
drop_input_cols = self._drop_input_cols
|
463
|
-
)
|
443
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
464
444
|
|
465
445
|
transform_handlers = ModelTransformerBuilder.build(
|
466
446
|
dataset=dataset,
|
@@ -479,7 +459,11 @@ class NearestCentroid(BaseTransformer):
|
|
479
459
|
return output_df
|
480
460
|
|
481
461
|
@available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
|
482
|
-
def fit_predict(
|
462
|
+
def fit_predict(
|
463
|
+
self,
|
464
|
+
dataset: Union[DataFrame, pd.DataFrame],
|
465
|
+
output_cols_prefix: str = "fit_predict_",
|
466
|
+
) -> Union[DataFrame, pd.DataFrame]:
|
483
467
|
""" Method not supported for this class.
|
484
468
|
|
485
469
|
|
@@ -504,22 +488,104 @@ class NearestCentroid(BaseTransformer):
|
|
504
488
|
)
|
505
489
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
506
490
|
drop_input_cols=self._drop_input_cols,
|
507
|
-
expected_output_cols_list=
|
491
|
+
expected_output_cols_list=(
|
492
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
493
|
+
),
|
508
494
|
)
|
509
495
|
self._sklearn_object = fitted_estimator
|
510
496
|
self._is_fitted = True
|
511
497
|
return output_result
|
512
498
|
|
499
|
+
|
500
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
501
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
502
|
+
""" Method not supported for this class.
|
503
|
+
|
513
504
|
|
514
|
-
|
515
|
-
|
516
|
-
|
505
|
+
Raises:
|
506
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
507
|
+
|
508
|
+
Args:
|
509
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
510
|
+
Snowpark or Pandas DataFrame.
|
511
|
+
output_cols_prefix: Prefix for the response columns
|
517
512
|
Returns:
|
518
513
|
Transformed dataset.
|
519
514
|
"""
|
520
|
-
self.
|
521
|
-
|
522
|
-
|
515
|
+
self._infer_input_output_cols(dataset)
|
516
|
+
super()._check_dataset_type(dataset)
|
517
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
518
|
+
estimator=self._sklearn_object,
|
519
|
+
dataset=dataset,
|
520
|
+
input_cols=self.input_cols,
|
521
|
+
label_cols=self.label_cols,
|
522
|
+
sample_weight_col=self.sample_weight_col,
|
523
|
+
autogenerated=self._autogenerated,
|
524
|
+
subproject=_SUBPROJECT,
|
525
|
+
)
|
526
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
527
|
+
drop_input_cols=self._drop_input_cols,
|
528
|
+
expected_output_cols_list=self.output_cols,
|
529
|
+
)
|
530
|
+
self._sklearn_object = fitted_estimator
|
531
|
+
self._is_fitted = True
|
532
|
+
return output_result
|
533
|
+
|
534
|
+
|
535
|
+
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
536
|
+
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
537
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
538
|
+
"""
|
539
|
+
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
540
|
+
# The following condition is introduced for kneighbors methods, and not used in other methods
|
541
|
+
if output_cols:
|
542
|
+
output_cols = [
|
543
|
+
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
544
|
+
for c in output_cols
|
545
|
+
]
|
546
|
+
elif getattr(self._sklearn_object, "classes_", None) is None:
|
547
|
+
output_cols = [output_cols_prefix]
|
548
|
+
elif self._sklearn_object is not None:
|
549
|
+
classes = self._sklearn_object.classes_
|
550
|
+
if isinstance(classes, numpy.ndarray):
|
551
|
+
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
552
|
+
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
553
|
+
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
554
|
+
output_cols = []
|
555
|
+
for i, cl in enumerate(classes):
|
556
|
+
# For binary classification, there is only one output column for each class
|
557
|
+
# ndarray as the two classes are complementary.
|
558
|
+
if len(cl) == 2:
|
559
|
+
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
560
|
+
else:
|
561
|
+
output_cols.extend([
|
562
|
+
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
563
|
+
])
|
564
|
+
else:
|
565
|
+
output_cols = []
|
566
|
+
|
567
|
+
# Make sure column names are valid snowflake identifiers.
|
568
|
+
assert output_cols is not None # Make MyPy happy
|
569
|
+
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
570
|
+
|
571
|
+
return rv
|
572
|
+
|
573
|
+
def _align_expected_output_names(
|
574
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
575
|
+
) -> List[str]:
|
576
|
+
# in case the inferred output column names dimension is different
|
577
|
+
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
578
|
+
output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
|
579
|
+
output_df_columns = list(output_df_pd.columns)
|
580
|
+
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
581
|
+
if self.sample_weight_col:
|
582
|
+
output_df_columns_set -= set(self.sample_weight_col)
|
583
|
+
# if the dimension of inferred output column names is correct; use it
|
584
|
+
if len(expected_output_cols_list) == len(output_df_columns_set):
|
585
|
+
return expected_output_cols_list
|
586
|
+
# otherwise, use the sklearn estimator's output
|
587
|
+
else:
|
588
|
+
return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
523
589
|
|
524
590
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
525
591
|
@telemetry.send_api_usage_telemetry(
|
@@ -551,24 +617,26 @@ class NearestCentroid(BaseTransformer):
|
|
551
617
|
# are specific to the type of dataset used.
|
552
618
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
553
619
|
|
620
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
621
|
+
|
554
622
|
if isinstance(dataset, DataFrame):
|
555
|
-
self.
|
556
|
-
|
557
|
-
|
558
|
-
|
559
|
-
|
623
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
624
|
+
self._deps = self._get_dependencies()
|
625
|
+
assert isinstance(
|
626
|
+
dataset._session, Session
|
627
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
560
628
|
transform_kwargs = dict(
|
561
629
|
session=dataset._session,
|
562
630
|
dependencies=self._deps,
|
563
|
-
drop_input_cols
|
631
|
+
drop_input_cols=self._drop_input_cols,
|
564
632
|
expected_output_cols_type="float",
|
565
633
|
)
|
634
|
+
expected_output_cols = self._align_expected_output_names(
|
635
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
636
|
+
)
|
566
637
|
|
567
638
|
elif isinstance(dataset, pd.DataFrame):
|
568
|
-
transform_kwargs = dict(
|
569
|
-
snowpark_input_cols = self._snowpark_cols,
|
570
|
-
drop_input_cols = self._drop_input_cols
|
571
|
-
)
|
639
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
572
640
|
|
573
641
|
transform_handlers = ModelTransformerBuilder.build(
|
574
642
|
dataset=dataset,
|
@@ -580,7 +648,7 @@ class NearestCentroid(BaseTransformer):
|
|
580
648
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
581
649
|
inference_method=inference_method,
|
582
650
|
input_cols=self.input_cols,
|
583
|
-
expected_output_cols=
|
651
|
+
expected_output_cols=expected_output_cols,
|
584
652
|
**transform_kwargs
|
585
653
|
)
|
586
654
|
return output_df
|
@@ -610,29 +678,30 @@ class NearestCentroid(BaseTransformer):
|
|
610
678
|
Output dataset with log probability of the sample for each class in the model.
|
611
679
|
"""
|
612
680
|
super()._check_dataset_type(dataset)
|
613
|
-
inference_method="predict_log_proba"
|
681
|
+
inference_method = "predict_log_proba"
|
682
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
614
683
|
|
615
684
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
616
685
|
# are specific to the type of dataset used.
|
617
686
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
618
687
|
|
619
688
|
if isinstance(dataset, DataFrame):
|
620
|
-
self.
|
621
|
-
|
622
|
-
|
623
|
-
|
624
|
-
|
689
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
690
|
+
self._deps = self._get_dependencies()
|
691
|
+
assert isinstance(
|
692
|
+
dataset._session, Session
|
693
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
625
694
|
transform_kwargs = dict(
|
626
695
|
session=dataset._session,
|
627
696
|
dependencies=self._deps,
|
628
|
-
drop_input_cols
|
697
|
+
drop_input_cols=self._drop_input_cols,
|
629
698
|
expected_output_cols_type="float",
|
630
699
|
)
|
700
|
+
expected_output_cols = self._align_expected_output_names(
|
701
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
702
|
+
)
|
631
703
|
elif isinstance(dataset, pd.DataFrame):
|
632
|
-
transform_kwargs = dict(
|
633
|
-
snowpark_input_cols = self._snowpark_cols,
|
634
|
-
drop_input_cols = self._drop_input_cols
|
635
|
-
)
|
704
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
636
705
|
|
637
706
|
transform_handlers = ModelTransformerBuilder.build(
|
638
707
|
dataset=dataset,
|
@@ -645,7 +714,7 @@ class NearestCentroid(BaseTransformer):
|
|
645
714
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
646
715
|
inference_method=inference_method,
|
647
716
|
input_cols=self.input_cols,
|
648
|
-
expected_output_cols=
|
717
|
+
expected_output_cols=expected_output_cols,
|
649
718
|
**transform_kwargs
|
650
719
|
)
|
651
720
|
return output_df
|
@@ -671,30 +740,32 @@ class NearestCentroid(BaseTransformer):
|
|
671
740
|
Output dataset with results of the decision function for the samples in input dataset.
|
672
741
|
"""
|
673
742
|
super()._check_dataset_type(dataset)
|
674
|
-
inference_method="decision_function"
|
743
|
+
inference_method = "decision_function"
|
675
744
|
|
676
745
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
677
746
|
# are specific to the type of dataset used.
|
678
747
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
679
748
|
|
749
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
750
|
+
|
680
751
|
if isinstance(dataset, DataFrame):
|
681
|
-
self.
|
682
|
-
|
683
|
-
|
684
|
-
|
685
|
-
|
752
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
753
|
+
self._deps = self._get_dependencies()
|
754
|
+
assert isinstance(
|
755
|
+
dataset._session, Session
|
756
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
686
757
|
transform_kwargs = dict(
|
687
758
|
session=dataset._session,
|
688
759
|
dependencies=self._deps,
|
689
|
-
drop_input_cols
|
760
|
+
drop_input_cols=self._drop_input_cols,
|
690
761
|
expected_output_cols_type="float",
|
691
762
|
)
|
763
|
+
expected_output_cols = self._align_expected_output_names(
|
764
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
765
|
+
)
|
692
766
|
|
693
767
|
elif isinstance(dataset, pd.DataFrame):
|
694
|
-
transform_kwargs = dict(
|
695
|
-
snowpark_input_cols = self._snowpark_cols,
|
696
|
-
drop_input_cols = self._drop_input_cols
|
697
|
-
)
|
768
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
698
769
|
|
699
770
|
transform_handlers = ModelTransformerBuilder.build(
|
700
771
|
dataset=dataset,
|
@@ -707,7 +778,7 @@ class NearestCentroid(BaseTransformer):
|
|
707
778
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
708
779
|
inference_method=inference_method,
|
709
780
|
input_cols=self.input_cols,
|
710
|
-
expected_output_cols=
|
781
|
+
expected_output_cols=expected_output_cols,
|
711
782
|
**transform_kwargs
|
712
783
|
)
|
713
784
|
return output_df
|
@@ -736,17 +807,17 @@ class NearestCentroid(BaseTransformer):
|
|
736
807
|
Output dataset with probability of the sample for each class in the model.
|
737
808
|
"""
|
738
809
|
super()._check_dataset_type(dataset)
|
739
|
-
inference_method="score_samples"
|
810
|
+
inference_method = "score_samples"
|
740
811
|
|
741
812
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
742
813
|
# are specific to the type of dataset used.
|
743
814
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
744
815
|
|
816
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
817
|
+
|
745
818
|
if isinstance(dataset, DataFrame):
|
746
|
-
self.
|
747
|
-
|
748
|
-
inference_method=inference_method,
|
749
|
-
)
|
819
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
820
|
+
self._deps = self._get_dependencies()
|
750
821
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
751
822
|
transform_kwargs = dict(
|
752
823
|
session=dataset._session,
|
@@ -754,6 +825,9 @@ class NearestCentroid(BaseTransformer):
|
|
754
825
|
drop_input_cols = self._drop_input_cols,
|
755
826
|
expected_output_cols_type="float",
|
756
827
|
)
|
828
|
+
expected_output_cols = self._align_expected_output_names(
|
829
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
830
|
+
)
|
757
831
|
|
758
832
|
elif isinstance(dataset, pd.DataFrame):
|
759
833
|
transform_kwargs = dict(
|
@@ -772,7 +846,7 @@ class NearestCentroid(BaseTransformer):
|
|
772
846
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
773
847
|
inference_method=inference_method,
|
774
848
|
input_cols=self.input_cols,
|
775
|
-
expected_output_cols=
|
849
|
+
expected_output_cols=expected_output_cols,
|
776
850
|
**transform_kwargs
|
777
851
|
)
|
778
852
|
return output_df
|
@@ -807,17 +881,15 @@ class NearestCentroid(BaseTransformer):
|
|
807
881
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
808
882
|
|
809
883
|
if isinstance(dataset, DataFrame):
|
810
|
-
self.
|
811
|
-
|
812
|
-
inference_method="score",
|
813
|
-
)
|
884
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
885
|
+
self._deps = self._get_dependencies()
|
814
886
|
selected_cols = self._get_active_columns()
|
815
887
|
if len(selected_cols) > 0:
|
816
888
|
dataset = dataset.select(selected_cols)
|
817
889
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
818
890
|
transform_kwargs = dict(
|
819
891
|
session=dataset._session,
|
820
|
-
dependencies=
|
892
|
+
dependencies=self._deps,
|
821
893
|
score_sproc_imports=['sklearn'],
|
822
894
|
)
|
823
895
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -882,11 +954,8 @@ class NearestCentroid(BaseTransformer):
|
|
882
954
|
|
883
955
|
if isinstance(dataset, DataFrame):
|
884
956
|
|
885
|
-
self.
|
886
|
-
|
887
|
-
inference_method=inference_method,
|
888
|
-
|
889
|
-
)
|
957
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
958
|
+
self._deps = self._get_dependencies()
|
890
959
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
891
960
|
transform_kwargs = dict(
|
892
961
|
session = dataset._session,
|
@@ -919,50 +988,84 @@ class NearestCentroid(BaseTransformer):
|
|
919
988
|
)
|
920
989
|
return output_df
|
921
990
|
|
991
|
+
|
992
|
+
|
993
|
+
def to_sklearn(self) -> Any:
|
994
|
+
"""Get sklearn.neighbors.NearestCentroid object.
|
995
|
+
"""
|
996
|
+
if self._sklearn_object is None:
|
997
|
+
self._sklearn_object = self._create_sklearn_object()
|
998
|
+
return self._sklearn_object
|
999
|
+
|
1000
|
+
def to_xgboost(self) -> Any:
|
1001
|
+
raise exceptions.SnowflakeMLException(
|
1002
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1003
|
+
original_exception=AttributeError(
|
1004
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1005
|
+
"to_xgboost()",
|
1006
|
+
"to_sklearn()"
|
1007
|
+
)
|
1008
|
+
),
|
1009
|
+
)
|
1010
|
+
|
1011
|
+
def to_lightgbm(self) -> Any:
|
1012
|
+
raise exceptions.SnowflakeMLException(
|
1013
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1014
|
+
original_exception=AttributeError(
|
1015
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1016
|
+
"to_lightgbm()",
|
1017
|
+
"to_sklearn()"
|
1018
|
+
)
|
1019
|
+
),
|
1020
|
+
)
|
1021
|
+
|
1022
|
+
def _get_dependencies(self) -> List[str]:
|
1023
|
+
return self._deps
|
1024
|
+
|
922
1025
|
|
923
|
-
def
|
1026
|
+
def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
924
1027
|
self._model_signature_dict = dict()
|
925
1028
|
|
926
1029
|
PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
|
927
1030
|
|
928
|
-
inputs = list(_infer_signature(dataset[self.input_cols], "input"))
|
1031
|
+
inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
|
929
1032
|
outputs: List[BaseFeatureSpec] = []
|
930
1033
|
if hasattr(self, "predict"):
|
931
1034
|
# keep mypy happy
|
932
|
-
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1035
|
+
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
933
1036
|
# For classifier, the type of predict is the same as the type of label
|
934
|
-
if self._sklearn_object._estimator_type ==
|
935
|
-
|
1037
|
+
if self._sklearn_object._estimator_type == "classifier":
|
1038
|
+
# label columns is the desired type for output
|
936
1039
|
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
937
1040
|
# rename the output columns
|
938
1041
|
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
939
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
940
|
-
|
941
|
-
|
1042
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1043
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1044
|
+
)
|
942
1045
|
# For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
|
943
1046
|
# For outlier models, returns -1 for outliers and 1 for inliers.
|
944
|
-
# Clusterer returns int64 cluster labels.
|
1047
|
+
# Clusterer returns int64 cluster labels.
|
945
1048
|
elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
|
946
1049
|
outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
|
947
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
948
|
-
|
949
|
-
|
950
|
-
|
1050
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1051
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1052
|
+
)
|
1053
|
+
|
951
1054
|
# For regressor, the type of predict is float64
|
952
|
-
elif self._sklearn_object._estimator_type ==
|
1055
|
+
elif self._sklearn_object._estimator_type == "regressor":
|
953
1056
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
954
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
955
|
-
|
956
|
-
|
957
|
-
|
1057
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1058
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1059
|
+
)
|
1060
|
+
|
958
1061
|
for prob_func in PROB_FUNCTIONS:
|
959
1062
|
if hasattr(self, prob_func):
|
960
1063
|
output_cols_prefix: str = f"{prob_func}_"
|
961
1064
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
962
1065
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
963
|
-
self._model_signature_dict[prob_func] = ModelSignature(
|
964
|
-
|
965
|
-
|
1066
|
+
self._model_signature_dict[prob_func] = ModelSignature(
|
1067
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1068
|
+
)
|
966
1069
|
|
967
1070
|
# Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
|
968
1071
|
items = list(self._model_signature_dict.items())
|
@@ -975,10 +1078,10 @@ class NearestCentroid(BaseTransformer):
|
|
975
1078
|
"""Returns model signature of current class.
|
976
1079
|
|
977
1080
|
Raises:
|
978
|
-
|
1081
|
+
SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
|
979
1082
|
|
980
1083
|
Returns:
|
981
|
-
Dict
|
1084
|
+
Dict with each method and its input output signature
|
982
1085
|
"""
|
983
1086
|
if self._model_signature_dict is None:
|
984
1087
|
raise exceptions.SnowflakeMLException(
|
@@ -986,35 +1089,3 @@ class NearestCentroid(BaseTransformer):
|
|
986
1089
|
original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
|
987
1090
|
)
|
988
1091
|
return self._model_signature_dict
|
989
|
-
|
990
|
-
def to_sklearn(self) -> Any:
|
991
|
-
"""Get sklearn.neighbors.NearestCentroid object.
|
992
|
-
"""
|
993
|
-
if self._sklearn_object is None:
|
994
|
-
self._sklearn_object = self._create_sklearn_object()
|
995
|
-
return self._sklearn_object
|
996
|
-
|
997
|
-
def to_xgboost(self) -> Any:
|
998
|
-
raise exceptions.SnowflakeMLException(
|
999
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1000
|
-
original_exception=AttributeError(
|
1001
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1002
|
-
"to_xgboost()",
|
1003
|
-
"to_sklearn()"
|
1004
|
-
)
|
1005
|
-
),
|
1006
|
-
)
|
1007
|
-
|
1008
|
-
def to_lightgbm(self) -> Any:
|
1009
|
-
raise exceptions.SnowflakeMLException(
|
1010
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1011
|
-
original_exception=AttributeError(
|
1012
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1013
|
-
"to_lightgbm()",
|
1014
|
-
"to_sklearn()"
|
1015
|
-
)
|
1016
|
-
),
|
1017
|
-
)
|
1018
|
-
|
1019
|
-
def _get_dependencies(self) -> List[str]:
|
1020
|
-
return self._deps
|