snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
|
|
33
33
|
BatchInferenceKwargsTypedDict,
|
34
34
|
ScoreKwargsTypedDict
|
35
35
|
)
|
36
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
37
|
+
from snowflake.ml.model.model_signature import (
|
38
|
+
BaseFeatureSpec,
|
39
|
+
DataType,
|
40
|
+
FeatureSpec,
|
41
|
+
ModelSignature,
|
42
|
+
_infer_signature,
|
43
|
+
_rename_signature_with_snowflake_identifiers,
|
44
|
+
)
|
36
45
|
|
37
46
|
from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
|
38
47
|
|
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
|
|
43
52
|
validate_sklearn_args,
|
44
53
|
)
|
45
54
|
|
46
|
-
from snowflake.ml.model.model_signature import (
|
47
|
-
DataType,
|
48
|
-
FeatureSpec,
|
49
|
-
ModelSignature,
|
50
|
-
_infer_signature,
|
51
|
-
_rename_signature_with_snowflake_identifiers,
|
52
|
-
BaseFeatureSpec,
|
53
|
-
)
|
54
|
-
from snowflake.ml.model._signatures import utils as model_signature_utils
|
55
|
-
|
56
55
|
_PROJECT = "ModelDevelopment"
|
57
56
|
# Derive subproject from module name by removing "sklearn"
|
58
57
|
# and converting module name from underscore to CamelCase
|
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklear
|
|
61
60
|
|
62
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
63
62
|
|
64
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
65
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
66
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
67
|
-
return check
|
68
|
-
|
69
|
-
|
70
63
|
class KMeans(BaseTransformer):
|
71
64
|
r"""K-Means clustering
|
72
65
|
For more details on this class, see [sklearn.cluster.KMeans]
|
@@ -277,12 +270,7 @@ class KMeans(BaseTransformer):
|
|
277
270
|
)
|
278
271
|
return selected_cols
|
279
272
|
|
280
|
-
|
281
|
-
project=_PROJECT,
|
282
|
-
subproject=_SUBPROJECT,
|
283
|
-
custom_tags=dict([("autogen", True)]),
|
284
|
-
)
|
285
|
-
def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "KMeans":
|
273
|
+
def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "KMeans":
|
286
274
|
"""Compute k-means clustering
|
287
275
|
For more details on this function, see [sklearn.cluster.KMeans.fit]
|
288
276
|
(https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html#sklearn.cluster.KMeans.fit)
|
@@ -309,12 +297,14 @@ class KMeans(BaseTransformer):
|
|
309
297
|
|
310
298
|
self._snowpark_cols = dataset.select(self.input_cols).columns
|
311
299
|
|
312
|
-
|
300
|
+
# If we are already in a stored procedure, no need to kick off another one.
|
313
301
|
if SNOWML_SPROC_ENV in os.environ:
|
314
302
|
statement_params = telemetry.get_function_usage_statement_params(
|
315
303
|
project=_PROJECT,
|
316
304
|
subproject=_SUBPROJECT,
|
317
|
-
function_name=telemetry.get_statement_params_full_func_name(
|
305
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
306
|
+
inspect.currentframe(), KMeans.__class__.__name__
|
307
|
+
),
|
318
308
|
api_calls=[Session.call],
|
319
309
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
320
310
|
)
|
@@ -335,27 +325,24 @@ class KMeans(BaseTransformer):
|
|
335
325
|
)
|
336
326
|
self._sklearn_object = model_trainer.train()
|
337
327
|
self._is_fitted = True
|
338
|
-
self.
|
328
|
+
self._generate_model_signatures(dataset)
|
339
329
|
return self
|
340
330
|
|
341
331
|
def _batch_inference_validate_snowpark(
|
342
332
|
self,
|
343
333
|
dataset: DataFrame,
|
344
334
|
inference_method: str,
|
345
|
-
) ->
|
346
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
347
|
-
return the available package that exists in the snowflake anaconda channel
|
335
|
+
) -> None:
|
336
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
348
337
|
|
349
338
|
Args:
|
350
339
|
dataset: snowpark dataframe
|
351
340
|
inference_method: the inference method such as predict, score...
|
352
|
-
|
341
|
+
|
353
342
|
Raises:
|
354
343
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
355
344
|
SnowflakeMLException: If the session is None, raise error
|
356
345
|
|
357
|
-
Returns:
|
358
|
-
A list of available package that exists in the snowflake anaconda channel
|
359
346
|
"""
|
360
347
|
if not self._is_fitted:
|
361
348
|
raise exceptions.SnowflakeMLException(
|
@@ -373,9 +360,7 @@ class KMeans(BaseTransformer):
|
|
373
360
|
"Session must not specified for snowpark dataset."
|
374
361
|
),
|
375
362
|
)
|
376
|
-
|
377
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
378
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
363
|
+
|
379
364
|
|
380
365
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
381
366
|
@telemetry.send_api_usage_telemetry(
|
@@ -411,7 +396,9 @@ class KMeans(BaseTransformer):
|
|
411
396
|
# when it is classifier, infer the datatype from label columns
|
412
397
|
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
413
398
|
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
414
|
-
label_cols_signatures = [
|
399
|
+
label_cols_signatures = [
|
400
|
+
row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
|
401
|
+
]
|
415
402
|
if len(label_cols_signatures) == 0:
|
416
403
|
error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
|
417
404
|
raise exceptions.SnowflakeMLException(
|
@@ -419,25 +406,23 @@ class KMeans(BaseTransformer):
|
|
419
406
|
original_exception=ValueError(error_str),
|
420
407
|
)
|
421
408
|
|
422
|
-
expected_type_inferred = convert_sp_to_sf_type(
|
423
|
-
label_cols_signatures[0].as_snowpark_type()
|
424
|
-
)
|
409
|
+
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
425
410
|
|
426
|
-
self.
|
427
|
-
|
411
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
412
|
+
self._deps = self._get_dependencies()
|
413
|
+
assert isinstance(
|
414
|
+
dataset._session, Session
|
415
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
428
416
|
|
429
417
|
transform_kwargs = dict(
|
430
|
-
session
|
431
|
-
dependencies
|
432
|
-
drop_input_cols
|
433
|
-
expected_output_cols_type
|
418
|
+
session=dataset._session,
|
419
|
+
dependencies=self._deps,
|
420
|
+
drop_input_cols=self._drop_input_cols,
|
421
|
+
expected_output_cols_type=expected_type_inferred,
|
434
422
|
)
|
435
423
|
|
436
424
|
elif isinstance(dataset, pd.DataFrame):
|
437
|
-
transform_kwargs = dict(
|
438
|
-
snowpark_input_cols = self._snowpark_cols,
|
439
|
-
drop_input_cols = self._drop_input_cols
|
440
|
-
)
|
425
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
441
426
|
|
442
427
|
transform_handlers = ModelTransformerBuilder.build(
|
443
428
|
dataset=dataset,
|
@@ -479,7 +464,7 @@ class KMeans(BaseTransformer):
|
|
479
464
|
Transformed dataset.
|
480
465
|
"""
|
481
466
|
super()._check_dataset_type(dataset)
|
482
|
-
inference_method="transform"
|
467
|
+
inference_method = "transform"
|
483
468
|
|
484
469
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
485
470
|
# are specific to the type of dataset used.
|
@@ -509,24 +494,19 @@ class KMeans(BaseTransformer):
|
|
509
494
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
510
495
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
511
496
|
|
512
|
-
self.
|
513
|
-
|
514
|
-
inference_method=inference_method,
|
515
|
-
)
|
497
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
498
|
+
self._deps = self._get_dependencies()
|
516
499
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
517
500
|
|
518
501
|
transform_kwargs = dict(
|
519
|
-
session
|
520
|
-
dependencies
|
521
|
-
drop_input_cols
|
522
|
-
expected_output_cols_type
|
502
|
+
session=dataset._session,
|
503
|
+
dependencies=self._deps,
|
504
|
+
drop_input_cols=self._drop_input_cols,
|
505
|
+
expected_output_cols_type=expected_dtype,
|
523
506
|
)
|
524
507
|
|
525
508
|
elif isinstance(dataset, pd.DataFrame):
|
526
|
-
transform_kwargs = dict(
|
527
|
-
snowpark_input_cols = self._snowpark_cols,
|
528
|
-
drop_input_cols = self._drop_input_cols
|
529
|
-
)
|
509
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
530
510
|
|
531
511
|
transform_handlers = ModelTransformerBuilder.build(
|
532
512
|
dataset=dataset,
|
@@ -545,7 +525,11 @@ class KMeans(BaseTransformer):
|
|
545
525
|
return output_df
|
546
526
|
|
547
527
|
@available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
|
548
|
-
def fit_predict(
|
528
|
+
def fit_predict(
|
529
|
+
self,
|
530
|
+
dataset: Union[DataFrame, pd.DataFrame],
|
531
|
+
output_cols_prefix: str = "fit_predict_",
|
532
|
+
) -> Union[DataFrame, pd.DataFrame]:
|
549
533
|
""" Compute cluster centers and predict cluster index for each sample
|
550
534
|
For more details on this function, see [sklearn.cluster.KMeans.fit_predict]
|
551
535
|
(https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html#sklearn.cluster.KMeans.fit_predict)
|
@@ -572,22 +556,106 @@ class KMeans(BaseTransformer):
|
|
572
556
|
)
|
573
557
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
574
558
|
drop_input_cols=self._drop_input_cols,
|
575
|
-
expected_output_cols_list=
|
559
|
+
expected_output_cols_list=(
|
560
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
561
|
+
),
|
576
562
|
)
|
577
563
|
self._sklearn_object = fitted_estimator
|
578
564
|
self._is_fitted = True
|
579
565
|
return output_result
|
580
566
|
|
567
|
+
|
568
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
569
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
570
|
+
""" Compute clustering and transform X to cluster-distance space
|
571
|
+
For more details on this function, see [sklearn.cluster.KMeans.fit_transform]
|
572
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html#sklearn.cluster.KMeans.fit_transform)
|
573
|
+
|
574
|
+
|
575
|
+
Raises:
|
576
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
581
577
|
|
582
|
-
|
583
|
-
|
584
|
-
|
578
|
+
Args:
|
579
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
580
|
+
Snowpark or Pandas DataFrame.
|
581
|
+
output_cols_prefix: Prefix for the response columns
|
585
582
|
Returns:
|
586
583
|
Transformed dataset.
|
587
584
|
"""
|
588
|
-
self.
|
589
|
-
|
590
|
-
|
585
|
+
self._infer_input_output_cols(dataset)
|
586
|
+
super()._check_dataset_type(dataset)
|
587
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
588
|
+
estimator=self._sklearn_object,
|
589
|
+
dataset=dataset,
|
590
|
+
input_cols=self.input_cols,
|
591
|
+
label_cols=self.label_cols,
|
592
|
+
sample_weight_col=self.sample_weight_col,
|
593
|
+
autogenerated=self._autogenerated,
|
594
|
+
subproject=_SUBPROJECT,
|
595
|
+
)
|
596
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
597
|
+
drop_input_cols=self._drop_input_cols,
|
598
|
+
expected_output_cols_list=self.output_cols,
|
599
|
+
)
|
600
|
+
self._sklearn_object = fitted_estimator
|
601
|
+
self._is_fitted = True
|
602
|
+
return output_result
|
603
|
+
|
604
|
+
|
605
|
+
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
606
|
+
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
607
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
608
|
+
"""
|
609
|
+
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
610
|
+
# The following condition is introduced for kneighbors methods, and not used in other methods
|
611
|
+
if output_cols:
|
612
|
+
output_cols = [
|
613
|
+
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
614
|
+
for c in output_cols
|
615
|
+
]
|
616
|
+
elif getattr(self._sklearn_object, "classes_", None) is None:
|
617
|
+
output_cols = [output_cols_prefix]
|
618
|
+
elif self._sklearn_object is not None:
|
619
|
+
classes = self._sklearn_object.classes_
|
620
|
+
if isinstance(classes, numpy.ndarray):
|
621
|
+
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
622
|
+
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
623
|
+
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
624
|
+
output_cols = []
|
625
|
+
for i, cl in enumerate(classes):
|
626
|
+
# For binary classification, there is only one output column for each class
|
627
|
+
# ndarray as the two classes are complementary.
|
628
|
+
if len(cl) == 2:
|
629
|
+
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
630
|
+
else:
|
631
|
+
output_cols.extend([
|
632
|
+
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
633
|
+
])
|
634
|
+
else:
|
635
|
+
output_cols = []
|
636
|
+
|
637
|
+
# Make sure column names are valid snowflake identifiers.
|
638
|
+
assert output_cols is not None # Make MyPy happy
|
639
|
+
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
640
|
+
|
641
|
+
return rv
|
642
|
+
|
643
|
+
def _align_expected_output_names(
|
644
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
645
|
+
) -> List[str]:
|
646
|
+
# in case the inferred output column names dimension is different
|
647
|
+
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
648
|
+
output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
|
649
|
+
output_df_columns = list(output_df_pd.columns)
|
650
|
+
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
651
|
+
if self.sample_weight_col:
|
652
|
+
output_df_columns_set -= set(self.sample_weight_col)
|
653
|
+
# if the dimension of inferred output column names is correct; use it
|
654
|
+
if len(expected_output_cols_list) == len(output_df_columns_set):
|
655
|
+
return expected_output_cols_list
|
656
|
+
# otherwise, use the sklearn estimator's output
|
657
|
+
else:
|
658
|
+
return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
591
659
|
|
592
660
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
593
661
|
@telemetry.send_api_usage_telemetry(
|
@@ -619,24 +687,26 @@ class KMeans(BaseTransformer):
|
|
619
687
|
# are specific to the type of dataset used.
|
620
688
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
621
689
|
|
690
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
691
|
+
|
622
692
|
if isinstance(dataset, DataFrame):
|
623
|
-
self.
|
624
|
-
|
625
|
-
|
626
|
-
|
627
|
-
|
693
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
694
|
+
self._deps = self._get_dependencies()
|
695
|
+
assert isinstance(
|
696
|
+
dataset._session, Session
|
697
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
628
698
|
transform_kwargs = dict(
|
629
699
|
session=dataset._session,
|
630
700
|
dependencies=self._deps,
|
631
|
-
drop_input_cols
|
701
|
+
drop_input_cols=self._drop_input_cols,
|
632
702
|
expected_output_cols_type="float",
|
633
703
|
)
|
704
|
+
expected_output_cols = self._align_expected_output_names(
|
705
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
706
|
+
)
|
634
707
|
|
635
708
|
elif isinstance(dataset, pd.DataFrame):
|
636
|
-
transform_kwargs = dict(
|
637
|
-
snowpark_input_cols = self._snowpark_cols,
|
638
|
-
drop_input_cols = self._drop_input_cols
|
639
|
-
)
|
709
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
640
710
|
|
641
711
|
transform_handlers = ModelTransformerBuilder.build(
|
642
712
|
dataset=dataset,
|
@@ -648,7 +718,7 @@ class KMeans(BaseTransformer):
|
|
648
718
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
649
719
|
inference_method=inference_method,
|
650
720
|
input_cols=self.input_cols,
|
651
|
-
expected_output_cols=
|
721
|
+
expected_output_cols=expected_output_cols,
|
652
722
|
**transform_kwargs
|
653
723
|
)
|
654
724
|
return output_df
|
@@ -678,29 +748,30 @@ class KMeans(BaseTransformer):
|
|
678
748
|
Output dataset with log probability of the sample for each class in the model.
|
679
749
|
"""
|
680
750
|
super()._check_dataset_type(dataset)
|
681
|
-
inference_method="predict_log_proba"
|
751
|
+
inference_method = "predict_log_proba"
|
752
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
682
753
|
|
683
754
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
684
755
|
# are specific to the type of dataset used.
|
685
756
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
686
757
|
|
687
758
|
if isinstance(dataset, DataFrame):
|
688
|
-
self.
|
689
|
-
|
690
|
-
|
691
|
-
|
692
|
-
|
759
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
760
|
+
self._deps = self._get_dependencies()
|
761
|
+
assert isinstance(
|
762
|
+
dataset._session, Session
|
763
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
693
764
|
transform_kwargs = dict(
|
694
765
|
session=dataset._session,
|
695
766
|
dependencies=self._deps,
|
696
|
-
drop_input_cols
|
767
|
+
drop_input_cols=self._drop_input_cols,
|
697
768
|
expected_output_cols_type="float",
|
698
769
|
)
|
770
|
+
expected_output_cols = self._align_expected_output_names(
|
771
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
772
|
+
)
|
699
773
|
elif isinstance(dataset, pd.DataFrame):
|
700
|
-
transform_kwargs = dict(
|
701
|
-
snowpark_input_cols = self._snowpark_cols,
|
702
|
-
drop_input_cols = self._drop_input_cols
|
703
|
-
)
|
774
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
704
775
|
|
705
776
|
transform_handlers = ModelTransformerBuilder.build(
|
706
777
|
dataset=dataset,
|
@@ -713,7 +784,7 @@ class KMeans(BaseTransformer):
|
|
713
784
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
714
785
|
inference_method=inference_method,
|
715
786
|
input_cols=self.input_cols,
|
716
|
-
expected_output_cols=
|
787
|
+
expected_output_cols=expected_output_cols,
|
717
788
|
**transform_kwargs
|
718
789
|
)
|
719
790
|
return output_df
|
@@ -739,30 +810,32 @@ class KMeans(BaseTransformer):
|
|
739
810
|
Output dataset with results of the decision function for the samples in input dataset.
|
740
811
|
"""
|
741
812
|
super()._check_dataset_type(dataset)
|
742
|
-
inference_method="decision_function"
|
813
|
+
inference_method = "decision_function"
|
743
814
|
|
744
815
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
745
816
|
# are specific to the type of dataset used.
|
746
817
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
747
818
|
|
819
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
820
|
+
|
748
821
|
if isinstance(dataset, DataFrame):
|
749
|
-
self.
|
750
|
-
|
751
|
-
|
752
|
-
|
753
|
-
|
822
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
823
|
+
self._deps = self._get_dependencies()
|
824
|
+
assert isinstance(
|
825
|
+
dataset._session, Session
|
826
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
754
827
|
transform_kwargs = dict(
|
755
828
|
session=dataset._session,
|
756
829
|
dependencies=self._deps,
|
757
|
-
drop_input_cols
|
830
|
+
drop_input_cols=self._drop_input_cols,
|
758
831
|
expected_output_cols_type="float",
|
759
832
|
)
|
833
|
+
expected_output_cols = self._align_expected_output_names(
|
834
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
835
|
+
)
|
760
836
|
|
761
837
|
elif isinstance(dataset, pd.DataFrame):
|
762
|
-
transform_kwargs = dict(
|
763
|
-
snowpark_input_cols = self._snowpark_cols,
|
764
|
-
drop_input_cols = self._drop_input_cols
|
765
|
-
)
|
838
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
766
839
|
|
767
840
|
transform_handlers = ModelTransformerBuilder.build(
|
768
841
|
dataset=dataset,
|
@@ -775,7 +848,7 @@ class KMeans(BaseTransformer):
|
|
775
848
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
776
849
|
inference_method=inference_method,
|
777
850
|
input_cols=self.input_cols,
|
778
|
-
expected_output_cols=
|
851
|
+
expected_output_cols=expected_output_cols,
|
779
852
|
**transform_kwargs
|
780
853
|
)
|
781
854
|
return output_df
|
@@ -804,17 +877,17 @@ class KMeans(BaseTransformer):
|
|
804
877
|
Output dataset with probability of the sample for each class in the model.
|
805
878
|
"""
|
806
879
|
super()._check_dataset_type(dataset)
|
807
|
-
inference_method="score_samples"
|
880
|
+
inference_method = "score_samples"
|
808
881
|
|
809
882
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
810
883
|
# are specific to the type of dataset used.
|
811
884
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
812
885
|
|
886
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
887
|
+
|
813
888
|
if isinstance(dataset, DataFrame):
|
814
|
-
self.
|
815
|
-
|
816
|
-
inference_method=inference_method,
|
817
|
-
)
|
889
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
890
|
+
self._deps = self._get_dependencies()
|
818
891
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
819
892
|
transform_kwargs = dict(
|
820
893
|
session=dataset._session,
|
@@ -822,6 +895,9 @@ class KMeans(BaseTransformer):
|
|
822
895
|
drop_input_cols = self._drop_input_cols,
|
823
896
|
expected_output_cols_type="float",
|
824
897
|
)
|
898
|
+
expected_output_cols = self._align_expected_output_names(
|
899
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
900
|
+
)
|
825
901
|
|
826
902
|
elif isinstance(dataset, pd.DataFrame):
|
827
903
|
transform_kwargs = dict(
|
@@ -840,7 +916,7 @@ class KMeans(BaseTransformer):
|
|
840
916
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
841
917
|
inference_method=inference_method,
|
842
918
|
input_cols=self.input_cols,
|
843
|
-
expected_output_cols=
|
919
|
+
expected_output_cols=expected_output_cols,
|
844
920
|
**transform_kwargs
|
845
921
|
)
|
846
922
|
return output_df
|
@@ -875,17 +951,15 @@ class KMeans(BaseTransformer):
|
|
875
951
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
876
952
|
|
877
953
|
if isinstance(dataset, DataFrame):
|
878
|
-
self.
|
879
|
-
|
880
|
-
inference_method="score",
|
881
|
-
)
|
954
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
955
|
+
self._deps = self._get_dependencies()
|
882
956
|
selected_cols = self._get_active_columns()
|
883
957
|
if len(selected_cols) > 0:
|
884
958
|
dataset = dataset.select(selected_cols)
|
885
959
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
886
960
|
transform_kwargs = dict(
|
887
961
|
session=dataset._session,
|
888
|
-
dependencies=
|
962
|
+
dependencies=self._deps,
|
889
963
|
score_sproc_imports=['sklearn'],
|
890
964
|
)
|
891
965
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -950,11 +1024,8 @@ class KMeans(BaseTransformer):
|
|
950
1024
|
|
951
1025
|
if isinstance(dataset, DataFrame):
|
952
1026
|
|
953
|
-
self.
|
954
|
-
|
955
|
-
inference_method=inference_method,
|
956
|
-
|
957
|
-
)
|
1027
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1028
|
+
self._deps = self._get_dependencies()
|
958
1029
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
959
1030
|
transform_kwargs = dict(
|
960
1031
|
session = dataset._session,
|
@@ -987,50 +1058,84 @@ class KMeans(BaseTransformer):
|
|
987
1058
|
)
|
988
1059
|
return output_df
|
989
1060
|
|
1061
|
+
|
1062
|
+
|
1063
|
+
def to_sklearn(self) -> Any:
|
1064
|
+
"""Get sklearn.cluster.KMeans object.
|
1065
|
+
"""
|
1066
|
+
if self._sklearn_object is None:
|
1067
|
+
self._sklearn_object = self._create_sklearn_object()
|
1068
|
+
return self._sklearn_object
|
1069
|
+
|
1070
|
+
def to_xgboost(self) -> Any:
|
1071
|
+
raise exceptions.SnowflakeMLException(
|
1072
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1073
|
+
original_exception=AttributeError(
|
1074
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1075
|
+
"to_xgboost()",
|
1076
|
+
"to_sklearn()"
|
1077
|
+
)
|
1078
|
+
),
|
1079
|
+
)
|
990
1080
|
|
991
|
-
def
|
1081
|
+
def to_lightgbm(self) -> Any:
|
1082
|
+
raise exceptions.SnowflakeMLException(
|
1083
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1084
|
+
original_exception=AttributeError(
|
1085
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1086
|
+
"to_lightgbm()",
|
1087
|
+
"to_sklearn()"
|
1088
|
+
)
|
1089
|
+
),
|
1090
|
+
)
|
1091
|
+
|
1092
|
+
def _get_dependencies(self) -> List[str]:
|
1093
|
+
return self._deps
|
1094
|
+
|
1095
|
+
|
1096
|
+
def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
992
1097
|
self._model_signature_dict = dict()
|
993
1098
|
|
994
1099
|
PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
|
995
1100
|
|
996
|
-
inputs = list(_infer_signature(dataset[self.input_cols], "input"))
|
1101
|
+
inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
|
997
1102
|
outputs: List[BaseFeatureSpec] = []
|
998
1103
|
if hasattr(self, "predict"):
|
999
1104
|
# keep mypy happy
|
1000
|
-
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1105
|
+
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1001
1106
|
# For classifier, the type of predict is the same as the type of label
|
1002
|
-
if self._sklearn_object._estimator_type ==
|
1003
|
-
|
1107
|
+
if self._sklearn_object._estimator_type == "classifier":
|
1108
|
+
# label columns is the desired type for output
|
1004
1109
|
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
1005
1110
|
# rename the output columns
|
1006
1111
|
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
1007
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1008
|
-
|
1009
|
-
|
1112
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1113
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1114
|
+
)
|
1010
1115
|
# For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
|
1011
1116
|
# For outlier models, returns -1 for outliers and 1 for inliers.
|
1012
|
-
# Clusterer returns int64 cluster labels.
|
1117
|
+
# Clusterer returns int64 cluster labels.
|
1013
1118
|
elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
|
1014
1119
|
outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
|
1015
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1016
|
-
|
1017
|
-
|
1018
|
-
|
1120
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1121
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1122
|
+
)
|
1123
|
+
|
1019
1124
|
# For regressor, the type of predict is float64
|
1020
|
-
elif self._sklearn_object._estimator_type ==
|
1125
|
+
elif self._sklearn_object._estimator_type == "regressor":
|
1021
1126
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
1022
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1023
|
-
|
1024
|
-
|
1025
|
-
|
1127
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1128
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1129
|
+
)
|
1130
|
+
|
1026
1131
|
for prob_func in PROB_FUNCTIONS:
|
1027
1132
|
if hasattr(self, prob_func):
|
1028
1133
|
output_cols_prefix: str = f"{prob_func}_"
|
1029
1134
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
1030
1135
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
1031
|
-
self._model_signature_dict[prob_func] = ModelSignature(
|
1032
|
-
|
1033
|
-
|
1136
|
+
self._model_signature_dict[prob_func] = ModelSignature(
|
1137
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1138
|
+
)
|
1034
1139
|
|
1035
1140
|
# Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
|
1036
1141
|
items = list(self._model_signature_dict.items())
|
@@ -1043,10 +1148,10 @@ class KMeans(BaseTransformer):
|
|
1043
1148
|
"""Returns model signature of current class.
|
1044
1149
|
|
1045
1150
|
Raises:
|
1046
|
-
|
1151
|
+
SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
|
1047
1152
|
|
1048
1153
|
Returns:
|
1049
|
-
Dict
|
1154
|
+
Dict with each method and its input output signature
|
1050
1155
|
"""
|
1051
1156
|
if self._model_signature_dict is None:
|
1052
1157
|
raise exceptions.SnowflakeMLException(
|
@@ -1054,35 +1159,3 @@ class KMeans(BaseTransformer):
|
|
1054
1159
|
original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
|
1055
1160
|
)
|
1056
1161
|
return self._model_signature_dict
|
1057
|
-
|
1058
|
-
def to_sklearn(self) -> Any:
|
1059
|
-
"""Get sklearn.cluster.KMeans object.
|
1060
|
-
"""
|
1061
|
-
if self._sklearn_object is None:
|
1062
|
-
self._sklearn_object = self._create_sklearn_object()
|
1063
|
-
return self._sklearn_object
|
1064
|
-
|
1065
|
-
def to_xgboost(self) -> Any:
|
1066
|
-
raise exceptions.SnowflakeMLException(
|
1067
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1068
|
-
original_exception=AttributeError(
|
1069
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1070
|
-
"to_xgboost()",
|
1071
|
-
"to_sklearn()"
|
1072
|
-
)
|
1073
|
-
),
|
1074
|
-
)
|
1075
|
-
|
1076
|
-
def to_lightgbm(self) -> Any:
|
1077
|
-
raise exceptions.SnowflakeMLException(
|
1078
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1079
|
-
original_exception=AttributeError(
|
1080
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1081
|
-
"to_lightgbm()",
|
1082
|
-
"to_sklearn()"
|
1083
|
-
)
|
1084
|
-
),
|
1085
|
-
)
|
1086
|
-
|
1087
|
-
def _get_dependencies(self) -> List[str]:
|
1088
|
-
return self._deps
|