snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
|
|
33
33
|
BatchInferenceKwargsTypedDict,
|
34
34
|
ScoreKwargsTypedDict
|
35
35
|
)
|
36
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
37
|
+
from snowflake.ml.model.model_signature import (
|
38
|
+
BaseFeatureSpec,
|
39
|
+
DataType,
|
40
|
+
FeatureSpec,
|
41
|
+
ModelSignature,
|
42
|
+
_infer_signature,
|
43
|
+
_rename_signature_with_snowflake_identifiers,
|
44
|
+
)
|
36
45
|
|
37
46
|
from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
|
38
47
|
|
@@ -43,16 +52,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
|
|
43
52
|
validate_sklearn_args,
|
44
53
|
)
|
45
54
|
|
46
|
-
from snowflake.ml.model.model_signature import (
|
47
|
-
DataType,
|
48
|
-
FeatureSpec,
|
49
|
-
ModelSignature,
|
50
|
-
_infer_signature,
|
51
|
-
_rename_signature_with_snowflake_identifiers,
|
52
|
-
BaseFeatureSpec,
|
53
|
-
)
|
54
|
-
from snowflake.ml.model._signatures import utils as model_signature_utils
|
55
|
-
|
56
55
|
_PROJECT = "ModelDevelopment"
|
57
56
|
# Derive subproject from module name by removing "sklearn"
|
58
57
|
# and converting module name from underscore to CamelCase
|
@@ -61,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklear
|
|
61
60
|
|
62
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
63
62
|
|
64
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
65
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
66
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
67
|
-
return check
|
68
|
-
|
69
|
-
|
70
63
|
class MiniBatchKMeans(BaseTransformer):
|
71
64
|
r"""Mini-Batch K-Means clustering
|
72
65
|
For more details on this class, see [sklearn.cluster.MiniBatchKMeans]
|
@@ -303,12 +296,7 @@ class MiniBatchKMeans(BaseTransformer):
|
|
303
296
|
)
|
304
297
|
return selected_cols
|
305
298
|
|
306
|
-
|
307
|
-
project=_PROJECT,
|
308
|
-
subproject=_SUBPROJECT,
|
309
|
-
custom_tags=dict([("autogen", True)]),
|
310
|
-
)
|
311
|
-
def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "MiniBatchKMeans":
|
299
|
+
def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "MiniBatchKMeans":
|
312
300
|
"""Compute the centroids on X by chunking it into mini-batches
|
313
301
|
For more details on this function, see [sklearn.cluster.MiniBatchKMeans.fit]
|
314
302
|
(https://scikit-learn.org/stable/modules/generated/sklearn.cluster.MiniBatchKMeans.html#sklearn.cluster.MiniBatchKMeans.fit)
|
@@ -335,12 +323,14 @@ class MiniBatchKMeans(BaseTransformer):
|
|
335
323
|
|
336
324
|
self._snowpark_cols = dataset.select(self.input_cols).columns
|
337
325
|
|
338
|
-
|
326
|
+
# If we are already in a stored procedure, no need to kick off another one.
|
339
327
|
if SNOWML_SPROC_ENV in os.environ:
|
340
328
|
statement_params = telemetry.get_function_usage_statement_params(
|
341
329
|
project=_PROJECT,
|
342
330
|
subproject=_SUBPROJECT,
|
343
|
-
function_name=telemetry.get_statement_params_full_func_name(
|
331
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
332
|
+
inspect.currentframe(), MiniBatchKMeans.__class__.__name__
|
333
|
+
),
|
344
334
|
api_calls=[Session.call],
|
345
335
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
346
336
|
)
|
@@ -361,27 +351,24 @@ class MiniBatchKMeans(BaseTransformer):
|
|
361
351
|
)
|
362
352
|
self._sklearn_object = model_trainer.train()
|
363
353
|
self._is_fitted = True
|
364
|
-
self.
|
354
|
+
self._generate_model_signatures(dataset)
|
365
355
|
return self
|
366
356
|
|
367
357
|
def _batch_inference_validate_snowpark(
|
368
358
|
self,
|
369
359
|
dataset: DataFrame,
|
370
360
|
inference_method: str,
|
371
|
-
) ->
|
372
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
373
|
-
return the available package that exists in the snowflake anaconda channel
|
361
|
+
) -> None:
|
362
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
374
363
|
|
375
364
|
Args:
|
376
365
|
dataset: snowpark dataframe
|
377
366
|
inference_method: the inference method such as predict, score...
|
378
|
-
|
367
|
+
|
379
368
|
Raises:
|
380
369
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
381
370
|
SnowflakeMLException: If the session is None, raise error
|
382
371
|
|
383
|
-
Returns:
|
384
|
-
A list of available package that exists in the snowflake anaconda channel
|
385
372
|
"""
|
386
373
|
if not self._is_fitted:
|
387
374
|
raise exceptions.SnowflakeMLException(
|
@@ -399,9 +386,7 @@ class MiniBatchKMeans(BaseTransformer):
|
|
399
386
|
"Session must not specified for snowpark dataset."
|
400
387
|
),
|
401
388
|
)
|
402
|
-
|
403
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
404
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
389
|
+
|
405
390
|
|
406
391
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
407
392
|
@telemetry.send_api_usage_telemetry(
|
@@ -437,7 +422,9 @@ class MiniBatchKMeans(BaseTransformer):
|
|
437
422
|
# when it is classifier, infer the datatype from label columns
|
438
423
|
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
439
424
|
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
440
|
-
label_cols_signatures = [
|
425
|
+
label_cols_signatures = [
|
426
|
+
row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
|
427
|
+
]
|
441
428
|
if len(label_cols_signatures) == 0:
|
442
429
|
error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
|
443
430
|
raise exceptions.SnowflakeMLException(
|
@@ -445,25 +432,23 @@ class MiniBatchKMeans(BaseTransformer):
|
|
445
432
|
original_exception=ValueError(error_str),
|
446
433
|
)
|
447
434
|
|
448
|
-
expected_type_inferred = convert_sp_to_sf_type(
|
449
|
-
label_cols_signatures[0].as_snowpark_type()
|
450
|
-
)
|
435
|
+
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
451
436
|
|
452
|
-
self.
|
453
|
-
|
437
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
438
|
+
self._deps = self._get_dependencies()
|
439
|
+
assert isinstance(
|
440
|
+
dataset._session, Session
|
441
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
454
442
|
|
455
443
|
transform_kwargs = dict(
|
456
|
-
session
|
457
|
-
dependencies
|
458
|
-
drop_input_cols
|
459
|
-
expected_output_cols_type
|
444
|
+
session=dataset._session,
|
445
|
+
dependencies=self._deps,
|
446
|
+
drop_input_cols=self._drop_input_cols,
|
447
|
+
expected_output_cols_type=expected_type_inferred,
|
460
448
|
)
|
461
449
|
|
462
450
|
elif isinstance(dataset, pd.DataFrame):
|
463
|
-
transform_kwargs = dict(
|
464
|
-
snowpark_input_cols = self._snowpark_cols,
|
465
|
-
drop_input_cols = self._drop_input_cols
|
466
|
-
)
|
451
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
467
452
|
|
468
453
|
transform_handlers = ModelTransformerBuilder.build(
|
469
454
|
dataset=dataset,
|
@@ -505,7 +490,7 @@ class MiniBatchKMeans(BaseTransformer):
|
|
505
490
|
Transformed dataset.
|
506
491
|
"""
|
507
492
|
super()._check_dataset_type(dataset)
|
508
|
-
inference_method="transform"
|
493
|
+
inference_method = "transform"
|
509
494
|
|
510
495
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
511
496
|
# are specific to the type of dataset used.
|
@@ -535,24 +520,19 @@ class MiniBatchKMeans(BaseTransformer):
|
|
535
520
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
536
521
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
537
522
|
|
538
|
-
self.
|
539
|
-
|
540
|
-
inference_method=inference_method,
|
541
|
-
)
|
523
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
524
|
+
self._deps = self._get_dependencies()
|
542
525
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
543
526
|
|
544
527
|
transform_kwargs = dict(
|
545
|
-
session
|
546
|
-
dependencies
|
547
|
-
drop_input_cols
|
548
|
-
expected_output_cols_type
|
528
|
+
session=dataset._session,
|
529
|
+
dependencies=self._deps,
|
530
|
+
drop_input_cols=self._drop_input_cols,
|
531
|
+
expected_output_cols_type=expected_dtype,
|
549
532
|
)
|
550
533
|
|
551
534
|
elif isinstance(dataset, pd.DataFrame):
|
552
|
-
transform_kwargs = dict(
|
553
|
-
snowpark_input_cols = self._snowpark_cols,
|
554
|
-
drop_input_cols = self._drop_input_cols
|
555
|
-
)
|
535
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
556
536
|
|
557
537
|
transform_handlers = ModelTransformerBuilder.build(
|
558
538
|
dataset=dataset,
|
@@ -571,7 +551,11 @@ class MiniBatchKMeans(BaseTransformer):
|
|
571
551
|
return output_df
|
572
552
|
|
573
553
|
@available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
|
574
|
-
def fit_predict(
|
554
|
+
def fit_predict(
|
555
|
+
self,
|
556
|
+
dataset: Union[DataFrame, pd.DataFrame],
|
557
|
+
output_cols_prefix: str = "fit_predict_",
|
558
|
+
) -> Union[DataFrame, pd.DataFrame]:
|
575
559
|
""" Compute cluster centers and predict cluster index for each sample
|
576
560
|
For more details on this function, see [sklearn.cluster.MiniBatchKMeans.fit_predict]
|
577
561
|
(https://scikit-learn.org/stable/modules/generated/sklearn.cluster.MiniBatchKMeans.html#sklearn.cluster.MiniBatchKMeans.fit_predict)
|
@@ -598,22 +582,106 @@ class MiniBatchKMeans(BaseTransformer):
|
|
598
582
|
)
|
599
583
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
600
584
|
drop_input_cols=self._drop_input_cols,
|
601
|
-
expected_output_cols_list=
|
585
|
+
expected_output_cols_list=(
|
586
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
587
|
+
),
|
602
588
|
)
|
603
589
|
self._sklearn_object = fitted_estimator
|
604
590
|
self._is_fitted = True
|
605
591
|
return output_result
|
606
592
|
|
593
|
+
|
594
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
595
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
596
|
+
""" Compute clustering and transform X to cluster-distance space
|
597
|
+
For more details on this function, see [sklearn.cluster.MiniBatchKMeans.fit_transform]
|
598
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.cluster.MiniBatchKMeans.html#sklearn.cluster.MiniBatchKMeans.fit_transform)
|
599
|
+
|
600
|
+
|
601
|
+
Raises:
|
602
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
607
603
|
|
608
|
-
|
609
|
-
|
610
|
-
|
604
|
+
Args:
|
605
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
606
|
+
Snowpark or Pandas DataFrame.
|
607
|
+
output_cols_prefix: Prefix for the response columns
|
611
608
|
Returns:
|
612
609
|
Transformed dataset.
|
613
610
|
"""
|
614
|
-
self.
|
615
|
-
|
616
|
-
|
611
|
+
self._infer_input_output_cols(dataset)
|
612
|
+
super()._check_dataset_type(dataset)
|
613
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
614
|
+
estimator=self._sklearn_object,
|
615
|
+
dataset=dataset,
|
616
|
+
input_cols=self.input_cols,
|
617
|
+
label_cols=self.label_cols,
|
618
|
+
sample_weight_col=self.sample_weight_col,
|
619
|
+
autogenerated=self._autogenerated,
|
620
|
+
subproject=_SUBPROJECT,
|
621
|
+
)
|
622
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
623
|
+
drop_input_cols=self._drop_input_cols,
|
624
|
+
expected_output_cols_list=self.output_cols,
|
625
|
+
)
|
626
|
+
self._sklearn_object = fitted_estimator
|
627
|
+
self._is_fitted = True
|
628
|
+
return output_result
|
629
|
+
|
630
|
+
|
631
|
+
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
632
|
+
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
633
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
634
|
+
"""
|
635
|
+
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
636
|
+
# The following condition is introduced for kneighbors methods, and not used in other methods
|
637
|
+
if output_cols:
|
638
|
+
output_cols = [
|
639
|
+
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
640
|
+
for c in output_cols
|
641
|
+
]
|
642
|
+
elif getattr(self._sklearn_object, "classes_", None) is None:
|
643
|
+
output_cols = [output_cols_prefix]
|
644
|
+
elif self._sklearn_object is not None:
|
645
|
+
classes = self._sklearn_object.classes_
|
646
|
+
if isinstance(classes, numpy.ndarray):
|
647
|
+
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
648
|
+
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
649
|
+
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
650
|
+
output_cols = []
|
651
|
+
for i, cl in enumerate(classes):
|
652
|
+
# For binary classification, there is only one output column for each class
|
653
|
+
# ndarray as the two classes are complementary.
|
654
|
+
if len(cl) == 2:
|
655
|
+
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
656
|
+
else:
|
657
|
+
output_cols.extend([
|
658
|
+
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
659
|
+
])
|
660
|
+
else:
|
661
|
+
output_cols = []
|
662
|
+
|
663
|
+
# Make sure column names are valid snowflake identifiers.
|
664
|
+
assert output_cols is not None # Make MyPy happy
|
665
|
+
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
666
|
+
|
667
|
+
return rv
|
668
|
+
|
669
|
+
def _align_expected_output_names(
|
670
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
671
|
+
) -> List[str]:
|
672
|
+
# in case the inferred output column names dimension is different
|
673
|
+
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
674
|
+
output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
|
675
|
+
output_df_columns = list(output_df_pd.columns)
|
676
|
+
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
677
|
+
if self.sample_weight_col:
|
678
|
+
output_df_columns_set -= set(self.sample_weight_col)
|
679
|
+
# if the dimension of inferred output column names is correct; use it
|
680
|
+
if len(expected_output_cols_list) == len(output_df_columns_set):
|
681
|
+
return expected_output_cols_list
|
682
|
+
# otherwise, use the sklearn estimator's output
|
683
|
+
else:
|
684
|
+
return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
617
685
|
|
618
686
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
619
687
|
@telemetry.send_api_usage_telemetry(
|
@@ -645,24 +713,26 @@ class MiniBatchKMeans(BaseTransformer):
|
|
645
713
|
# are specific to the type of dataset used.
|
646
714
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
647
715
|
|
716
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
717
|
+
|
648
718
|
if isinstance(dataset, DataFrame):
|
649
|
-
self.
|
650
|
-
|
651
|
-
|
652
|
-
|
653
|
-
|
719
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
720
|
+
self._deps = self._get_dependencies()
|
721
|
+
assert isinstance(
|
722
|
+
dataset._session, Session
|
723
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
654
724
|
transform_kwargs = dict(
|
655
725
|
session=dataset._session,
|
656
726
|
dependencies=self._deps,
|
657
|
-
drop_input_cols
|
727
|
+
drop_input_cols=self._drop_input_cols,
|
658
728
|
expected_output_cols_type="float",
|
659
729
|
)
|
730
|
+
expected_output_cols = self._align_expected_output_names(
|
731
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
732
|
+
)
|
660
733
|
|
661
734
|
elif isinstance(dataset, pd.DataFrame):
|
662
|
-
transform_kwargs = dict(
|
663
|
-
snowpark_input_cols = self._snowpark_cols,
|
664
|
-
drop_input_cols = self._drop_input_cols
|
665
|
-
)
|
735
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
666
736
|
|
667
737
|
transform_handlers = ModelTransformerBuilder.build(
|
668
738
|
dataset=dataset,
|
@@ -674,7 +744,7 @@ class MiniBatchKMeans(BaseTransformer):
|
|
674
744
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
675
745
|
inference_method=inference_method,
|
676
746
|
input_cols=self.input_cols,
|
677
|
-
expected_output_cols=
|
747
|
+
expected_output_cols=expected_output_cols,
|
678
748
|
**transform_kwargs
|
679
749
|
)
|
680
750
|
return output_df
|
@@ -704,29 +774,30 @@ class MiniBatchKMeans(BaseTransformer):
|
|
704
774
|
Output dataset with log probability of the sample for each class in the model.
|
705
775
|
"""
|
706
776
|
super()._check_dataset_type(dataset)
|
707
|
-
inference_method="predict_log_proba"
|
777
|
+
inference_method = "predict_log_proba"
|
778
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
708
779
|
|
709
780
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
710
781
|
# are specific to the type of dataset used.
|
711
782
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
712
783
|
|
713
784
|
if isinstance(dataset, DataFrame):
|
714
|
-
self.
|
715
|
-
|
716
|
-
|
717
|
-
|
718
|
-
|
785
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
786
|
+
self._deps = self._get_dependencies()
|
787
|
+
assert isinstance(
|
788
|
+
dataset._session, Session
|
789
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
719
790
|
transform_kwargs = dict(
|
720
791
|
session=dataset._session,
|
721
792
|
dependencies=self._deps,
|
722
|
-
drop_input_cols
|
793
|
+
drop_input_cols=self._drop_input_cols,
|
723
794
|
expected_output_cols_type="float",
|
724
795
|
)
|
796
|
+
expected_output_cols = self._align_expected_output_names(
|
797
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
798
|
+
)
|
725
799
|
elif isinstance(dataset, pd.DataFrame):
|
726
|
-
transform_kwargs = dict(
|
727
|
-
snowpark_input_cols = self._snowpark_cols,
|
728
|
-
drop_input_cols = self._drop_input_cols
|
729
|
-
)
|
800
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
730
801
|
|
731
802
|
transform_handlers = ModelTransformerBuilder.build(
|
732
803
|
dataset=dataset,
|
@@ -739,7 +810,7 @@ class MiniBatchKMeans(BaseTransformer):
|
|
739
810
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
740
811
|
inference_method=inference_method,
|
741
812
|
input_cols=self.input_cols,
|
742
|
-
expected_output_cols=
|
813
|
+
expected_output_cols=expected_output_cols,
|
743
814
|
**transform_kwargs
|
744
815
|
)
|
745
816
|
return output_df
|
@@ -765,30 +836,32 @@ class MiniBatchKMeans(BaseTransformer):
|
|
765
836
|
Output dataset with results of the decision function for the samples in input dataset.
|
766
837
|
"""
|
767
838
|
super()._check_dataset_type(dataset)
|
768
|
-
inference_method="decision_function"
|
839
|
+
inference_method = "decision_function"
|
769
840
|
|
770
841
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
771
842
|
# are specific to the type of dataset used.
|
772
843
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
773
844
|
|
845
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
846
|
+
|
774
847
|
if isinstance(dataset, DataFrame):
|
775
|
-
self.
|
776
|
-
|
777
|
-
|
778
|
-
|
779
|
-
|
848
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
849
|
+
self._deps = self._get_dependencies()
|
850
|
+
assert isinstance(
|
851
|
+
dataset._session, Session
|
852
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
780
853
|
transform_kwargs = dict(
|
781
854
|
session=dataset._session,
|
782
855
|
dependencies=self._deps,
|
783
|
-
drop_input_cols
|
856
|
+
drop_input_cols=self._drop_input_cols,
|
784
857
|
expected_output_cols_type="float",
|
785
858
|
)
|
859
|
+
expected_output_cols = self._align_expected_output_names(
|
860
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
861
|
+
)
|
786
862
|
|
787
863
|
elif isinstance(dataset, pd.DataFrame):
|
788
|
-
transform_kwargs = dict(
|
789
|
-
snowpark_input_cols = self._snowpark_cols,
|
790
|
-
drop_input_cols = self._drop_input_cols
|
791
|
-
)
|
864
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
792
865
|
|
793
866
|
transform_handlers = ModelTransformerBuilder.build(
|
794
867
|
dataset=dataset,
|
@@ -801,7 +874,7 @@ class MiniBatchKMeans(BaseTransformer):
|
|
801
874
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
802
875
|
inference_method=inference_method,
|
803
876
|
input_cols=self.input_cols,
|
804
|
-
expected_output_cols=
|
877
|
+
expected_output_cols=expected_output_cols,
|
805
878
|
**transform_kwargs
|
806
879
|
)
|
807
880
|
return output_df
|
@@ -830,17 +903,17 @@ class MiniBatchKMeans(BaseTransformer):
|
|
830
903
|
Output dataset with probability of the sample for each class in the model.
|
831
904
|
"""
|
832
905
|
super()._check_dataset_type(dataset)
|
833
|
-
inference_method="score_samples"
|
906
|
+
inference_method = "score_samples"
|
834
907
|
|
835
908
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
836
909
|
# are specific to the type of dataset used.
|
837
910
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
838
911
|
|
912
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
913
|
+
|
839
914
|
if isinstance(dataset, DataFrame):
|
840
|
-
self.
|
841
|
-
|
842
|
-
inference_method=inference_method,
|
843
|
-
)
|
915
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
916
|
+
self._deps = self._get_dependencies()
|
844
917
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
845
918
|
transform_kwargs = dict(
|
846
919
|
session=dataset._session,
|
@@ -848,6 +921,9 @@ class MiniBatchKMeans(BaseTransformer):
|
|
848
921
|
drop_input_cols = self._drop_input_cols,
|
849
922
|
expected_output_cols_type="float",
|
850
923
|
)
|
924
|
+
expected_output_cols = self._align_expected_output_names(
|
925
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
926
|
+
)
|
851
927
|
|
852
928
|
elif isinstance(dataset, pd.DataFrame):
|
853
929
|
transform_kwargs = dict(
|
@@ -866,7 +942,7 @@ class MiniBatchKMeans(BaseTransformer):
|
|
866
942
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
867
943
|
inference_method=inference_method,
|
868
944
|
input_cols=self.input_cols,
|
869
|
-
expected_output_cols=
|
945
|
+
expected_output_cols=expected_output_cols,
|
870
946
|
**transform_kwargs
|
871
947
|
)
|
872
948
|
return output_df
|
@@ -901,17 +977,15 @@ class MiniBatchKMeans(BaseTransformer):
|
|
901
977
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
902
978
|
|
903
979
|
if isinstance(dataset, DataFrame):
|
904
|
-
self.
|
905
|
-
|
906
|
-
inference_method="score",
|
907
|
-
)
|
980
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
981
|
+
self._deps = self._get_dependencies()
|
908
982
|
selected_cols = self._get_active_columns()
|
909
983
|
if len(selected_cols) > 0:
|
910
984
|
dataset = dataset.select(selected_cols)
|
911
985
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
912
986
|
transform_kwargs = dict(
|
913
987
|
session=dataset._session,
|
914
|
-
dependencies=
|
988
|
+
dependencies=self._deps,
|
915
989
|
score_sproc_imports=['sklearn'],
|
916
990
|
)
|
917
991
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -976,11 +1050,8 @@ class MiniBatchKMeans(BaseTransformer):
|
|
976
1050
|
|
977
1051
|
if isinstance(dataset, DataFrame):
|
978
1052
|
|
979
|
-
self.
|
980
|
-
|
981
|
-
inference_method=inference_method,
|
982
|
-
|
983
|
-
)
|
1053
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1054
|
+
self._deps = self._get_dependencies()
|
984
1055
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
985
1056
|
transform_kwargs = dict(
|
986
1057
|
session = dataset._session,
|
@@ -1013,50 +1084,84 @@ class MiniBatchKMeans(BaseTransformer):
|
|
1013
1084
|
)
|
1014
1085
|
return output_df
|
1015
1086
|
|
1087
|
+
|
1088
|
+
|
1089
|
+
def to_sklearn(self) -> Any:
|
1090
|
+
"""Get sklearn.cluster.MiniBatchKMeans object.
|
1091
|
+
"""
|
1092
|
+
if self._sklearn_object is None:
|
1093
|
+
self._sklearn_object = self._create_sklearn_object()
|
1094
|
+
return self._sklearn_object
|
1095
|
+
|
1096
|
+
def to_xgboost(self) -> Any:
|
1097
|
+
raise exceptions.SnowflakeMLException(
|
1098
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1099
|
+
original_exception=AttributeError(
|
1100
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1101
|
+
"to_xgboost()",
|
1102
|
+
"to_sklearn()"
|
1103
|
+
)
|
1104
|
+
),
|
1105
|
+
)
|
1016
1106
|
|
1017
|
-
def
|
1107
|
+
def to_lightgbm(self) -> Any:
|
1108
|
+
raise exceptions.SnowflakeMLException(
|
1109
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1110
|
+
original_exception=AttributeError(
|
1111
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1112
|
+
"to_lightgbm()",
|
1113
|
+
"to_sklearn()"
|
1114
|
+
)
|
1115
|
+
),
|
1116
|
+
)
|
1117
|
+
|
1118
|
+
def _get_dependencies(self) -> List[str]:
|
1119
|
+
return self._deps
|
1120
|
+
|
1121
|
+
|
1122
|
+
def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
1018
1123
|
self._model_signature_dict = dict()
|
1019
1124
|
|
1020
1125
|
PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
|
1021
1126
|
|
1022
|
-
inputs = list(_infer_signature(dataset[self.input_cols], "input"))
|
1127
|
+
inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
|
1023
1128
|
outputs: List[BaseFeatureSpec] = []
|
1024
1129
|
if hasattr(self, "predict"):
|
1025
1130
|
# keep mypy happy
|
1026
|
-
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1131
|
+
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1027
1132
|
# For classifier, the type of predict is the same as the type of label
|
1028
|
-
if self._sklearn_object._estimator_type ==
|
1029
|
-
|
1133
|
+
if self._sklearn_object._estimator_type == "classifier":
|
1134
|
+
# label columns is the desired type for output
|
1030
1135
|
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
1031
1136
|
# rename the output columns
|
1032
1137
|
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
1033
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1034
|
-
|
1035
|
-
|
1138
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1139
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1140
|
+
)
|
1036
1141
|
# For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
|
1037
1142
|
# For outlier models, returns -1 for outliers and 1 for inliers.
|
1038
|
-
# Clusterer returns int64 cluster labels.
|
1143
|
+
# Clusterer returns int64 cluster labels.
|
1039
1144
|
elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
|
1040
1145
|
outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
|
1041
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1042
|
-
|
1043
|
-
|
1044
|
-
|
1146
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1147
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1148
|
+
)
|
1149
|
+
|
1045
1150
|
# For regressor, the type of predict is float64
|
1046
|
-
elif self._sklearn_object._estimator_type ==
|
1151
|
+
elif self._sklearn_object._estimator_type == "regressor":
|
1047
1152
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
1048
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
1049
|
-
|
1050
|
-
|
1051
|
-
|
1153
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1154
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1155
|
+
)
|
1156
|
+
|
1052
1157
|
for prob_func in PROB_FUNCTIONS:
|
1053
1158
|
if hasattr(self, prob_func):
|
1054
1159
|
output_cols_prefix: str = f"{prob_func}_"
|
1055
1160
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
1056
1161
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
1057
|
-
self._model_signature_dict[prob_func] = ModelSignature(
|
1058
|
-
|
1059
|
-
|
1162
|
+
self._model_signature_dict[prob_func] = ModelSignature(
|
1163
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1164
|
+
)
|
1060
1165
|
|
1061
1166
|
# Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
|
1062
1167
|
items = list(self._model_signature_dict.items())
|
@@ -1069,10 +1174,10 @@ class MiniBatchKMeans(BaseTransformer):
|
|
1069
1174
|
"""Returns model signature of current class.
|
1070
1175
|
|
1071
1176
|
Raises:
|
1072
|
-
|
1177
|
+
SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
|
1073
1178
|
|
1074
1179
|
Returns:
|
1075
|
-
Dict
|
1180
|
+
Dict with each method and its input output signature
|
1076
1181
|
"""
|
1077
1182
|
if self._model_signature_dict is None:
|
1078
1183
|
raise exceptions.SnowflakeMLException(
|
@@ -1080,35 +1185,3 @@ class MiniBatchKMeans(BaseTransformer):
|
|
1080
1185
|
original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
|
1081
1186
|
)
|
1082
1187
|
return self._model_signature_dict
|
1083
|
-
|
1084
|
-
def to_sklearn(self) -> Any:
|
1085
|
-
"""Get sklearn.cluster.MiniBatchKMeans object.
|
1086
|
-
"""
|
1087
|
-
if self._sklearn_object is None:
|
1088
|
-
self._sklearn_object = self._create_sklearn_object()
|
1089
|
-
return self._sklearn_object
|
1090
|
-
|
1091
|
-
def to_xgboost(self) -> Any:
|
1092
|
-
raise exceptions.SnowflakeMLException(
|
1093
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1094
|
-
original_exception=AttributeError(
|
1095
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1096
|
-
"to_xgboost()",
|
1097
|
-
"to_sklearn()"
|
1098
|
-
)
|
1099
|
-
),
|
1100
|
-
)
|
1101
|
-
|
1102
|
-
def to_lightgbm(self) -> Any:
|
1103
|
-
raise exceptions.SnowflakeMLException(
|
1104
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1105
|
-
original_exception=AttributeError(
|
1106
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1107
|
-
"to_lightgbm()",
|
1108
|
-
"to_sklearn()"
|
1109
|
-
)
|
1110
|
-
),
|
1111
|
-
)
|
1112
|
-
|
1113
|
-
def _get_dependencies(self) -> List[str]:
|
1114
|
-
return self._deps
|